From b1ac2ddb2c5ab232b8286b36c0e2290f353ffbc2 Mon Sep 17 00:00:00 2001 From: christ-tt Date: Tue, 21 Oct 2025 21:50:31 +0800 Subject: [PATCH 01/11] feat: real cont. batching --- src/parallax/server/executor.py | 12 +- src/parallax/server/request.py | 1 + src/parallax/server/scheduler.py | 198 +++++++++++++++++++++---------- tests/test_batch_scheduler.py | 111 +++++++++++++++++ tests/test_executor.py | 3 + 5 files changed, 254 insertions(+), 71 deletions(-) create mode 100644 tests/test_batch_scheduler.py diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 1da3df2e..9c77f15c 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -242,6 +242,7 @@ def __init__( micro_batch_ratio=micro_batch_ratio, is_first_peer=self.is_first_peer, tokenizer=self.tokenizer, + kv_cache_manager=self.kv_cache_manager if self.device == "mlx" else None, ) logger.debug( f"Scheduler initialized (max_batch_size={max_batch_size}, max_tokens={max_num_tokens_per_batch}, wait_ms={scheduler_wait_ms})" @@ -758,6 +759,7 @@ def _handle_cuda_input_requests(self, requests: List[Request]): if not self.is_last_peer: self.finished_batch.append(req) else: + # Mark ready for next decode step on first peer self.scheduler.enque_request(original_req) # detokenize and send to http server @@ -780,7 +782,7 @@ def _handle_cuda_input_requests(self, requests: List[Request]): req, IntermediateRequest ), "Non-first peers must receive IntermediateRequests." if req.is_finished or req.hidden_states is None: - self.scheduler.evict_request(req.request_id, req.status) + self.scheduler.evict_request(req.request_id) release_cuda_request(self.running_batch, req.request_id) if not self.is_last_peer: self.finished_batch.append(req) @@ -804,8 +806,6 @@ def _handle_input_requests(self, requests: List[Request]): # or IntermediateRequests from the last peer. for req in requests: if isinstance(req, InitialRequest): - if not self.kv_cache_manager.has_request(req.request_id): - self.kv_cache_manager.add_request(req, req.total_length) self.scheduler.enque_request(req) elif isinstance(req, IntermediateRequest): original_req = self.scheduler.get_running_request(req.request_id) @@ -874,14 +874,12 @@ def _handle_input_requests(self, requests: List[Request]): f"kv cache manager has {self.kv_cache_manager.tokens_in_cache} tokens, " f"memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" ) - self.scheduler.evict_request(req.request_id, req.status) + self.scheduler.evict_request(req.request_id) if not self.is_last_peer: self.finished_batch.append(req) else: # This is an active request, add it to the scheduler queue to be processed. self.scheduler.enque_request(req) - if not self.kv_cache_manager.has_request(req.request_id): - self.kv_cache_manager.add_request(req, req.total_length) def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Request: """Handle request state changes both inter and intra peers. 
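The hunks above stop calling `kv_cache_manager.add_request` from the executor, hand the cache manager to the scheduler on the MLX path only, and narrow `evict_request` to just a request id. The following is a minimal sketch of that ownership split; `MiniKV` and `MiniScheduler` are illustrative stand-ins, not the project's real classes, and the numbers are toy values.

# Sketch (stand-in classes, not the project's APIs): scheduler-owned KV admission.
from typing import Dict, Optional


class MiniKV:
    """Toy token-budget pool standing in for the real KVCacheManager."""

    def __init__(self, capacity_tokens: int):
        self.capacity_tokens = capacity_tokens
        self.resident: Dict[str, int] = {}

    def has_request(self, rid: str) -> bool:
        return rid in self.resident

    def add_request(self, rid: str, num_tokens: int) -> bool:
        # Refuse admission if the pool would overflow; the request stays queued.
        if sum(self.resident.values()) + num_tokens > self.capacity_tokens:
            return False
        self.resident[rid] = num_tokens
        return True

    def release_request(self, rid: str) -> None:
        self.resident.pop(rid, None)


class MiniScheduler:
    def __init__(self, kv: Optional[MiniKV] = None):
        # Backends that manage KV elsewhere (like the CUDA path above) pass kv=None,
        # so admission skips the cache check entirely.
        self.kv = kv
        self.running: Dict[str, int] = {}

    def admit(self, rid: str, prompt_len: int) -> bool:
        if self.kv is not None and not self.kv.has_request(rid):
            if not self.kv.add_request(rid, prompt_len):
                return False
        self.running[rid] = prompt_len
        return True

    def evict(self, rid: str) -> None:
        # Mirrors the narrowed evict_request(request_id): only scheduler bookkeeping
        # is dropped here; cache release stays a separate, caller-side step.
        self.running.pop(rid, None)


sched = MiniScheduler(kv=MiniKV(capacity_tokens=8))
assert sched.admit("a", 5)
assert not sched.admit("b", 6)      # 5 + 6 would overflow the 8-token pool
sched.evict("a")
sched.kv.release_request("a")       # caller releases the cache, as the executor does
assert sched.admit("b", 6)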
@@ -1204,7 +1202,7 @@ def run_loop(self): logger.exception(f"Error processing batch: {e}") # Naive error handling: release and evict all requests in the batch for req in batch_to_process: - self.scheduler.evict_request(req.request_id, req.status) + self.scheduler.evict_request(req.request_id) if self.device == "cuda": from parallax.sglang.batch_info import release_cuda_request diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index f31b4959..7a073262 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -103,6 +103,7 @@ def __init__( self.routing_table = routing_table self.sampling_params = sampling_params or SamplingParams() self.abort = False + self.ready_for_next_step = False @property def is_finished(self) -> bool: diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index fd500b81..8d32b492 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -1,15 +1,29 @@ """ -Scheduling requests to form batches.sche - -A scheduler will maintain a Priority Queue for request waiting pool. -We support continuous batching, and similar to TensorRT-LLM, - we favors prefill requests over decode requests. +Continuous Batching Scheduler. + +State managed by the scheduler: + 1. Prefill Wait Queue (FIFO): incoming prefill requests waiting for admission; + 2. Running Requests: inflight requests with KV-cache residency; + 3. Active Batch: the concrete batch chosen for the next model forward. + +We use an explicit 2-Phase approach: + * Phase 1 (Admission): wait queue -> running requests + Implemented by `admit_requests`. We admit requests when capacity + allows (e.g., max concurrent requests, memory availability). Admitted + requests get KV-cache residency and become inflight. + * Phase 2 (Batching): running requests -> active batch for actual forward + Implemented by `form_batch`. We prioritize PREFILL requests + first within `max_num_tokens_per_batch` and `micro_batch_size`, + then include DECODE requests that are marked ready for the next decode step. + +Our scheduler also handles tokenization and pre-processing for the First Peer's requests. """ -import heapq import time -from typing import Dict, List, Literal, Optional, Tuple +from collections import OrderedDict +from typing import Dict, List, Optional +from parallax.server.kv_cache import KVCacheManager from parallax.server.metrics import update_metrics from parallax.server.request import InitialRequest, Request, RequestStatus from parallax_utils.logging_config import get_logger @@ -19,29 +33,29 @@ class Scheduler: """ - A simple scheduler to manage requests and form them into batches. - This scheduler is designed to handle requests in a FIFO manner. 
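The new module docstring above describes the scheduler as a two-phase pipeline. The self-contained sketch below walks both phases end to end with simplified `Req` records (not the real `Request` class): admission drains a FIFO wait queue into the inflight set up to `max_batch_size`, and batching then takes prefills first plus only the decodes flagged ready, under `micro_batch_size` and the token budget. The `ready` field plays the role of the `ready_for_next_step` flag added to `Request` in this patch.

# Sketch of the 2-phase flow (simplified records, not the project's Request class).
from collections import OrderedDict
from dataclasses import dataclass


@dataclass
class Req:
    rid: str
    phase: str            # "prefill" or "decode"
    prompt_len: int = 1
    ready: bool = True    # stands in for Request.ready_for_next_step


def admit(wait_queue, running, max_batch_size):
    # Phase 1: FIFO wait queue -> inflight set, bounded by max_batch_size.
    while wait_queue and len(running) < max_batch_size:
        req = wait_queue.pop(0)
        running[req.rid] = req


def form_batch(running, micro_batch_size, token_budget):
    # Phase 2: prefills first (admission order), then decodes that are ready.
    batch, tokens = [], 0
    prefills = [r for r in running.values() if r.phase == "prefill"]
    decodes = [r for r in running.values() if r.phase == "decode" and r.ready]
    for req in prefills + decodes:
        if len(batch) >= micro_batch_size:
            break
        cost = req.prompt_len if req.phase == "prefill" else 1
        if tokens + cost > token_budget:
            continue            # too big for this forward pass; stays inflight
        batch.append(req)
        tokens += cost
    return batch


wait = [Req("p1", "prefill", prompt_len=6), Req("p2", "prefill", prompt_len=7)]
running = OrderedDict(d1=Req("d1", "decode"))
admit(wait, running, max_batch_size=3)
print([r.rid for r in form_batch(running, micro_batch_size=4, token_budget=16)])
# -> ['p1', 'p2', 'd1']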
+ 2-Phase approach: + * Phase 1: wait queue -> running requests (all inflight requests) + * Phase 2: running requests -> active batch (actual model forward) """ def __init__( self, max_batch_size: int = 16, max_num_tokens_per_batch: int = 4096, - prefill_priority: Literal[0, 1] = 0, scheduler_wait_ms: int = 200, micro_batch_ratio: int = 2, is_first_peer: bool = False, + kv_cache_manager: Optional[KVCacheManager] = None, **kwargs, ): """ Args: - max_batch_size: Maximum number of running requests; + max_batch_size: Maximum number of running / inflight requests; max_num_tokens_per_batch: Maxmimum number of prefill + decode tokens in a single batch; - prefill_priority: Priority for prefill requests, - default 0 for prefill, 1 for decode, 0 for higher priority; scheduler_wait_ms: The minimum time to wait before dispatching a batch; - micro_batch_ratio: micro_batch_size = max_batch_size // micro_batch_ratio - tokenizer: The tokenizer to use for the model. + micro_batch_ratio: micro_batch_size = max_batch_size // micro_batch_ratio; + tokenizer: The tokenizer to use for the model; + kv_cache_manager: The KV cache manager to use for the scheduler. """ self.max_batch_size = max_batch_size self.max_num_tokens_per_batch = max_num_tokens_per_batch @@ -55,15 +69,15 @@ def __init__( self.max_new_tokens = kwargs.get("max_new_tokens", 512) self.max_total_length = kwargs.get("max_total_length", 1024) - # Priority queue: (priority, arrival_time, request_id, request_object) - self._request_queue: List[Tuple[int, float, str, Request]] = [] + # Prefill wait queue (FIFO) for admission; supports moving chunked prefill to front + self._wait_queue: List[Request] = [] # Keeps track of all in-flight requests - self._running_requests: Dict[str, Request] = {} + self._running_requests: Dict[str, Request] = OrderedDict() + # The actual batch of requests for model forward runner + self._active_batch: Dict[str, Request] = {} + + self.kv_cache_manager = kv_cache_manager - self.priority_map = { - RequestStatus.PREFILLING: prefill_priority, - RequestStatus.DECODING: 1 - prefill_priority, - } self._last_dispatch_ts = time.time() # Track last reported running requests to avoid redundant metric updates self._last_reported_running_requests: int = 0 @@ -75,7 +89,7 @@ def __init__( @property def num_queued_requests(self) -> int: """Get the number of requests in the scheduler.""" - return len(self._request_queue) + return len(self._wait_queue) @property def num_running_requests(self) -> int: @@ -85,7 +99,7 @@ def num_running_requests(self) -> int: @property def has_pending_requests(self) -> bool: """Check if there are any pending requests in the scheduler.""" - return len(self._request_queue) > 0 + return len(self._wait_queue) > 0 def get_running_request(self, request_id: str) -> Optional[Request]: """Gets a request that is currently in the running state.""" @@ -100,7 +114,7 @@ def _prompt_string_to_request(self, request_str: str) -> InitialRequest: ) def enque_request(self, request: Request | str): - """Add a request to the scheduler.""" + """Enque a request to the scheduler's wait queue.""" if isinstance(request, str): request = self._prompt_string_to_request(request) @@ -110,25 +124,35 @@ def enque_request(self, request: Request | str): f"{request.status}. Not adding to the scheduler." 
) return - arrival_time = time.time() - priority = self.priority_map.get(request.status, 1) - heapq.heappush(self._request_queue, (priority, arrival_time, request.request_id, request)) - logger.debug(f"Request {request.request_id} added to the scheduler.") - logger.debug(f"Scheduler queue size: {len(self._request_queue)}") - # Running count does not change on enqueue; do not update metrics here - - def evict_request(self, request_id: str, status: Optional[RequestStatus] = None): + + # TODO: Handle chunked prefill. + if request.is_decoding: + rid = request.request_id + if rid not in self._running_requests: + raise ValueError( + f"Decode request {rid} must already be admitted (in running requests)." + ) + # Mark as ready and update recency ordering so earlier-ready decodes + # are encountered first during actual batch formation + self._running_requests.move_to_end(rid) + logger.debug(f"Decode request {rid} marked ready for next decode.") + return + + self._wait_queue.append(request) + request.ready_for_next_step = True + logger.debug( + f"Prefill request {request.request_id} added to the prefill wait queue (size={len(self._wait_queue)})." + ) + + def evict_request(self, request_id: str): """Removes a request from the scheduler's running queue.""" - _ = status # status is used by the first peer's logic but not here. if request_id in self._running_requests: self._running_requests.pop(request_id) logger.debug(f"Evicted request {request_id} from scheduler.") # Update metrics only if running count changed since last report try: curr = self.num_running_requests - # if curr != self._last_reported_running_requests: update_metrics(current_requests=curr) - # self._last_reported_running_requests = curr except Exception: pass else: @@ -179,48 +203,94 @@ def should_dispatch(self) -> bool: queued = self.num_queued_requests >= self.micro_batch_size return waited or queued + def admit_requests(self) -> None: + """Move requests from wait queue into running (inflight) set, up to capacity. + + Pushes admitted requests directly into the running set. + """ + while self._wait_queue and len(self._running_requests) < self.max_batch_size: + req = self._wait_queue.pop(0) + rid = req.request_id + if rid in self._running_requests: + # Already inflight; chunked-prefill, skip + continue + # Check kv cache pool + if self.kv_cache_manager is not None: + if not self.kv_cache_manager.has_request(req.request_id): + if not self.kv_cache_manager.add_request(req, req.total_length): + logger.warning( + f"Request {rid} can't be admit to running batch due to KV cache size." + ) + continue + self._running_requests[rid] = req + + # Reflect current running requests metric after admission + try: + curr = self.num_running_requests + if curr != self._last_reported_running_requests: + update_metrics(current_requests=curr) + self._last_reported_running_requests = curr + except Exception: + pass + + self._last_dispatch_ts = time.time() + return None + def form_batch(self) -> List[Request]: - """Get the next batch of requests. + """Form the active batch for the next forward pass. - At-most `micro_batch_size` requests will be returned. + - Select prefills first (FIFO by admission), then decodes that are ready + following the OrderedDict iteration order where ready decodes are + moved-to-end upon readiness, while respecting micro_batch_size and + max_num_tokens_per_batch. 
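The `form_batch` docstring above (its body follows in the next hunk) skips any candidate whose cost would exceed `max_num_tokens_per_batch` rather than stopping, so a cheap decode can still ride along when a large prefill does not fit. Here is a small sketch of just that budget rule, with toy (id, cost) pairs instead of real requests; `fill_within_budget` is an illustrative helper, not project code.

# Sketch of the token-budget rule in form_batch: over-budget candidates are skipped
# (continue), not treated as a hard stop, so cheaper work still gets batched.
def fill_within_budget(candidates, token_budget):
    batch, tokens = [], 0
    for rid, cost in candidates:
        if tokens + cost > token_budget:
            continue            # skip this one; it stays inflight for a later pass
        batch.append(rid)
        tokens += cost
    return batch, tokens


# One oversized prefill (cost 5) and one decode step (cost 1) with a budget of 1:
print(fill_within_budget([("p_big", 5), ("d", 1)], token_budget=1))
# -> (['d'], 1), the same selection test_token_budget_prefill_skipped_decode_taken expects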
""" - if not self.has_pending_requests: + # TODO: we need to fully decouple admit_requests and form_batch + # to overlap micro-batch scheduling with both model running & communication to other peers. + self.admit_requests() + if not self._running_requests: return [] inflight_tokens = 0 - - batch = [] - save_index = [] - for index, request in enumerate(self._request_queue): + batch: List[Request] = [] + + # Prefill candidates: preserve admission order via OrderedDict iteration + prefill_candidates: List[Request] = [ + req for req in self._running_requests.values() if req.is_prefill + ] + + # Decode candidates: only those ready, maintain OrderedDict order which was + # updated upon readiness (earlier-ready decodes appear earlier) + decode_ready_candidates: List[Request] = [ + req + for req in self._running_requests.values() + if req.is_decoding and req.ready_for_next_step + ] + + # 1) Fill with prefills first + for req in prefill_candidates: if len(batch) >= self.micro_batch_size: - save_index.append(index) + break + cost = req.prompt_len + if cost + inflight_tokens > self.max_num_tokens_per_batch: continue - _, _, rid, req = request + batch.append(req) + inflight_tokens += cost - cost = req.prompt_len if req.is_prefill else 1 + # 2) Fill remaining with ready decodes + for req in decode_ready_candidates: + if len(batch) >= self.micro_batch_size: + break + cost = 1 if cost + inflight_tokens > self.max_num_tokens_per_batch: - save_index.append(index) continue - - if rid not in self._running_requests: - if len(self._running_requests) >= self.max_batch_size: - save_index.append(index) - continue - batch.append(req) - self._running_requests[rid] = req - inflight_tokens += cost - self._request_queue = [self._request_queue[i] for i in save_index] + # Track the active batch mapping for introspection / downstream usage + self._active_batch = {r.request_id: r for r in batch} - # Reflect current running requests metric after forming the batch - try: - curr = self.num_running_requests - if curr != self._last_reported_running_requests: - update_metrics(current_requests=curr) - self._last_reported_running_requests = curr - except Exception: - pass + # Clear ready flags for decodes included in this batch + for r in batch: + r.ready_for_next_step = False return batch diff --git a/tests/test_batch_scheduler.py b/tests/test_batch_scheduler.py new file mode 100644 index 00000000..ca7a6546 --- /dev/null +++ b/tests/test_batch_scheduler.py @@ -0,0 +1,111 @@ +import pytest + +from parallax.server.request import InitialRequest, Request, RequestStatus +from parallax.server.scheduler import Scheduler + + +class FakeKVCacheManager: + def __init__(self, allow: bool = True): + self.allow = allow + self._reqs = set() + + def has_request(self, request_id: str) -> bool: + return request_id in self._reqs + + def add_request(self, request: Request, num_tokens: int = 0) -> bool: + if not self.allow: + return False + self._reqs.add(request.request_id) + return True + + +def make_prefill(rid: str, prompt_len: int) -> InitialRequest: + return InitialRequest(request_id=rid, input_ids=[0] * prompt_len) + + +def make_decode(rid: str) -> Request: + return Request(request_id=rid, status=RequestStatus.DECODING) + + +def test_prefill_fifo_and_micro_batch(): + sched = Scheduler(max_batch_size=8, max_num_tokens_per_batch=10_000, micro_batch_ratio=1) + # micro_batch_size = max_batch_size // ratio = 8 + # Enqueue 3 prefills in order + r1 = make_prefill("r1", 5) + r2 = make_prefill("r2", 6) + r3 = make_prefill("r3", 7) + 
sched.enque_request(r1) + sched.enque_request(r2) + sched.enque_request(r3) + + batch = sched.form_batch() + ids = [r.request_id for r in batch] + assert ids[:3] == ["r1", "r2", "r3"] + + +def test_decode_ready_order_and_prefill_first(): + # micro_batch_size = 3 + sched = Scheduler(max_batch_size=3, max_num_tokens_per_batch=10_000, micro_batch_ratio=1) + + # Two decodes already running + d1 = make_decode("d1") + d2 = make_decode("d2") + sched._running_requests[d1.request_id] = d1 + sched._running_requests[d2.request_id] = d2 + + # One prefill in queue + p1 = make_prefill("p1", 8) + sched.enque_request(p1) + + # Mark d1 ready first, then d2 + sched.enque_request(d1) # sets ready_for_next_step + LRU move_to_end + sched.enque_request(d2) + + sched.admit_requests() + batch = sched.form_batch() + ids = [r.request_id for r in batch] + + # Prefill first, then decodes in the order they became ready + assert ids == ["p1", "d1", "d2"] + + +def test_token_budget_prefill_skipped_decode_taken(): + # Token budget too small for prefill, but enough for decodes (cost=1) + sched = Scheduler(max_batch_size=2, max_num_tokens_per_batch=1, micro_batch_ratio=1) + + # One large prefill + p_big = make_prefill("p_big", 5) + sched.enque_request(p_big) + + # One ready decode already running + d = make_decode("d") + sched._running_requests[d.request_id] = d + sched.enque_request(d) + + sched.admit_requests() + batch = sched.form_batch() + ids = [r.request_id for r in batch] + assert ids == ["d"] + # ready flag should be reset after batching + assert getattr(d, "ready_for_next_step", False) is False + + +def test_kv_cache_admission_guard_blocks_prefill(): + # A KV manager that rejects additions + kv_mgr = FakeKVCacheManager(allow=False) + sched = Scheduler( + max_batch_size=2, + max_num_tokens_per_batch=100, + micro_batch_ratio=1, + kv_cache_manager=kv_mgr, + ) + p = make_prefill("p", 4) + sched.enque_request(p) + + # Admission should fail and running set remains empty; batch should be empty + sched.admit_requests() + batch = sched.form_batch() + assert len(batch) == 0 + assert sched.num_running_requests == 0 + + diff --git a/tests/test_executor.py b/tests/test_executor.py index 1dd16093..4c880d8e 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -52,7 +52,10 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps prefill_inputs_p1 = executor_peer1._prepare_batch_inputs(prefill_batch_p1) assert prefill_inputs_p1 is not None, "Failed to prepare batch inputs" prefill_batch_data = prefill_inputs_p1["prefill_batch"] + print("Pre Process") hidden_states_p1 = executor_peer1.process_batch(prefill_batch_data, return_decoded_tokens=False) + print(hidden_states_p1) + print("Process done") prefill_reqs_p2 = executor_peer1._prepare_next_batch_requests( requests=prefill_batch_data["requests"], hidden_states=hidden_states_p1, From 7bfb7b9da53eed4a9642534f47ca9eb2f43d0b87 Mon Sep 17 00:00:00 2001 From: christ-tt Date: Tue, 21 Oct 2025 22:00:31 +0800 Subject: [PATCH 02/11] fix: batch unit test --- tests/test_batch_scheduler.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_batch_scheduler.py b/tests/test_batch_scheduler.py index ca7a6546..80fa7e9c 100644 --- a/tests/test_batch_scheduler.py +++ b/tests/test_batch_scheduler.py @@ -1,5 +1,3 @@ -import pytest - from parallax.server.request import InitialRequest, Request, RequestStatus from parallax.server.scheduler import Scheduler @@ -23,8 +21,10 @@ def make_prefill(rid: str, prompt_len: int) -> 
InitialRequest: return InitialRequest(request_id=rid, input_ids=[0] * prompt_len) -def make_decode(rid: str) -> Request: - return Request(request_id=rid, status=RequestStatus.DECODING) +def make_decode(rid: str, ready: bool = True) -> Request: + r = Request(request_id=rid, status=RequestStatus.DECODING) + r.ready_for_next_step = ready + return r def test_prefill_fifo_and_micro_batch(): @@ -82,7 +82,6 @@ def test_token_budget_prefill_skipped_decode_taken(): sched._running_requests[d.request_id] = d sched.enque_request(d) - sched.admit_requests() batch = sched.form_batch() ids = [r.request_id for r in batch] assert ids == ["d"] @@ -103,9 +102,6 @@ def test_kv_cache_admission_guard_blocks_prefill(): sched.enque_request(p) # Admission should fail and running set remains empty; batch should be empty - sched.admit_requests() batch = sched.form_batch() assert len(batch) == 0 assert sched.num_running_requests == 0 - - From 9cc09e6c6b804cba42aa7c35b5f2c9bf4cd4228c Mon Sep 17 00:00:00 2001 From: christ-tt Date: Tue, 21 Oct 2025 22:27:13 +0800 Subject: [PATCH 03/11] fix: remove active batch and fix comments --- src/parallax/server/executor.py | 1 - src/parallax/server/scheduler.py | 7 +------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 9c77f15c..381a7a93 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -759,7 +759,6 @@ def _handle_cuda_input_requests(self, requests: List[Request]): if not self.is_last_peer: self.finished_batch.append(req) else: - # Mark ready for next decode step on first peer self.scheduler.enque_request(original_req) # detokenize and send to http server diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 8d32b492..e3f08935 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -69,12 +69,10 @@ def __init__( self.max_new_tokens = kwargs.get("max_new_tokens", 512) self.max_total_length = kwargs.get("max_total_length", 1024) - # Prefill wait queue (FIFO) for admission; supports moving chunked prefill to front + # Prefill wait queue (FIFO) for admission self._wait_queue: List[Request] = [] # Keeps track of all in-flight requests self._running_requests: Dict[str, Request] = OrderedDict() - # The actual batch of requests for model forward runner - self._active_batch: Dict[str, Request] = {} self.kv_cache_manager = kv_cache_manager @@ -286,9 +284,6 @@ def form_batch(self) -> List[Request]: batch.append(req) inflight_tokens += cost - # Track the active batch mapping for introspection / downstream usage - self._active_batch = {r.request_id: r for r in batch} - # Clear ready flags for decodes included in this batch for r in batch: r.ready_for_next_step = False From 0b3b897ad3071f240f15cd1e958a0fd2d69024f3 Mon Sep 17 00:00:00 2001 From: christ-tt Date: Tue, 21 Oct 2025 23:16:26 +0800 Subject: [PATCH 04/11] fix: decode request when enque --- src/parallax/server/scheduler.py | 7 ++++--- tests/test_executor.py | 3 --- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index e3f08935..b17daebe 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -123,6 +123,7 @@ def enque_request(self, request: Request | str): ) return + request.ready_for_next_step = True # TODO: Handle chunked prefill. 
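Patch 04 above moves the `ready_for_next_step = True` assignment ahead of the decode branch, so every enqueued request is marked ready, while keeping the rule that a decode may only be re-enqueued once it is already inflight. A usage sketch modeled on `tests/test_batch_scheduler.py`; it assumes the `parallax` package is importable and that the constructors behave as those tests show.

# Usage sketch, mirroring tests/test_batch_scheduler.py.
from parallax.server.request import InitialRequest, Request, RequestStatus
from parallax.server.scheduler import Scheduler

sched = Scheduler(max_batch_size=4, max_num_tokens_per_batch=1024, micro_batch_ratio=1)

# A prefill goes to the wait queue and is marked ready for the next step.
prefill = InitialRequest(request_id="p1", input_ids=[0] * 8)
sched.enque_request(prefill)
assert prefill.ready_for_next_step is True

# A decode that was never admitted is rejected loudly instead of being queued.
stray = Request(request_id="ghost", status=RequestStatus.DECODING)
try:
    sched.enque_request(stray)
except ValueError as exc:
    print(exc)  # "Decode request ghost must already be admitted (in running requests)."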
if request.is_decoding: rid = request.request_id @@ -130,14 +131,14 @@ def enque_request(self, request: Request | str): raise ValueError( f"Decode request {rid} must already be admitted (in running requests)." ) - # Mark as ready and update recency ordering so earlier-ready decodes - # are encountered first during actual batch formation + # Merge incoming decode readiness/state into the existing running request + self._running_requests[rid] = request + # Update recency ordering so earlier-ready decodes are encountered first during batching self._running_requests.move_to_end(rid) logger.debug(f"Decode request {rid} marked ready for next decode.") return self._wait_queue.append(request) - request.ready_for_next_step = True logger.debug( f"Prefill request {request.request_id} added to the prefill wait queue (size={len(self._wait_queue)})." ) diff --git a/tests/test_executor.py b/tests/test_executor.py index 4c880d8e..1dd16093 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -52,10 +52,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps prefill_inputs_p1 = executor_peer1._prepare_batch_inputs(prefill_batch_p1) assert prefill_inputs_p1 is not None, "Failed to prepare batch inputs" prefill_batch_data = prefill_inputs_p1["prefill_batch"] - print("Pre Process") hidden_states_p1 = executor_peer1.process_batch(prefill_batch_data, return_decoded_tokens=False) - print(hidden_states_p1) - print("Process done") prefill_reqs_p2 = executor_peer1._prepare_next_batch_requests( requests=prefill_batch_data["requests"], hidden_states=hidden_states_p1, From 3e0b5eefdf91e1ad8e66c1597f895357b95745e8 Mon Sep 17 00:00:00 2001 From: christ-tt Date: Wed, 22 Oct 2025 11:08:29 +0800 Subject: [PATCH 05/11] fix: default micro batch size --- src/parallax/server/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index b17daebe..ed3a4031 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -4,7 +4,7 @@ State managed by the scheduler: 1. Prefill Wait Queue (FIFO): incoming prefill requests waiting for admission; 2. Running Requests: inflight requests with KV-cache residency; - 3. Active Batch: the concrete batch chosen for the next model forward. +Main `form_batch` function will return the concrete batch chosen for the next model forward. 
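The decode branch shown above both re-stores the incoming object under its id and calls `move_to_end`, so iterating `_running_requests` visits decodes in the order they became ready while also holding the freshest request state. A standalone illustration with plain dicts standing in for request objects (`mark_ready` is an illustrative helper, not the real method):

# Sketch: re-storing plus move_to_end gives "earlier-ready decodes first" iteration order.
from collections import OrderedDict

running = OrderedDict()
running["d1"] = {"rid": "d1", "ready": False, "output_ids": [7]}
running["d2"] = {"rid": "d2", "ready": False, "output_ids": [3]}


def mark_ready(running, incoming):
    rid = incoming["rid"]
    if rid not in running:
        raise ValueError(f"Decode request {rid} must already be admitted.")
    incoming["ready"] = True
    running[rid] = incoming        # merge the freshest state for this request id
    running.move_to_end(rid)       # readiness order becomes the iteration order


# d1 becomes ready after producing a new token upstream.
mark_ready(running, {"rid": "d1", "ready": True, "output_ids": [7, 11]})
print(list(running))                # -> ['d2', 'd1']
print(running["d1"]["output_ids"])  # -> [7, 11]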
We use an explicit 2-Phase approach: * Phase 1 (Admission): wait queue -> running requests @@ -43,7 +43,7 @@ def __init__( max_batch_size: int = 16, max_num_tokens_per_batch: int = 4096, scheduler_wait_ms: int = 200, - micro_batch_ratio: int = 2, + micro_batch_ratio: int = 1, is_first_peer: bool = False, kv_cache_manager: Optional[KVCacheManager] = None, **kwargs, From 9f40a76efcccbd92217cb2ed3205a46a5aff4574 Mon Sep 17 00:00:00 2001 From: christ-tt Date: Wed, 29 Oct 2025 23:05:32 +0800 Subject: [PATCH 06/11] fix: remove explicit wait --- src/parallax/server/executor.py | 10 ++----- src/parallax/server/scheduler.py | 51 +++++++++++++------------------- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 14702078..fa6aaf0e 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -1055,7 +1055,7 @@ def _process_batch_mlx( ] # k_caches shape: (num_layers, B, num_kv_heads, L_padded, head_dim) logger.debug( - f"Processed batch with {len(prepared_inputs['requests'])} requests, " + f"Processing batch with {len(prepared_inputs['requests'])} requests, " f"request status: {prepared_inputs['requests'][0].status}, " f"hidden_states shape: {hidden_states.shape}, " f"k_caches shape: {k_caches.shape}, " @@ -1136,12 +1136,8 @@ def run_loop(self): ) self.finished_batch = [] - # 4. Check if we should form a batch - if not self.scheduler.should_dispatch(): - time.sleep(0.01) # prevent busy waiting - continue - - # 5. Form a batch from the scheduler's queue + # 4. Admit requests into running set up to capacity, then form batch + self.scheduler.admit_requests() batch_to_process = self.scheduler.form_batch() if not batch_to_process: continue diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index ed3a4031..042761fd 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -43,7 +43,7 @@ def __init__( max_batch_size: int = 16, max_num_tokens_per_batch: int = 4096, scheduler_wait_ms: int = 200, - micro_batch_ratio: int = 1, + micro_batch_ratio: int = 2, is_first_peer: bool = False, kv_cache_manager: Optional[KVCacheManager] = None, **kwargs, @@ -94,11 +94,6 @@ def num_running_requests(self) -> int: """Get the number of requests currently being processed.""" return len(self._running_requests) - @property - def has_pending_requests(self) -> bool: - """Check if there are any pending requests in the scheduler.""" - return len(self._wait_queue) > 0 - def get_running_request(self, request_id: str) -> Optional[Request]: """Gets a request that is currently in the running state.""" return self._running_requests.get(request_id) @@ -196,13 +191,7 @@ def check_and_update_request_status(self, request: InitialRequest) -> bool: return finished - def should_dispatch(self) -> bool: - """Helper check if the scheduler should dispatch a batch.""" - waited = (time.time() - self._last_dispatch_ts) * 1000 >= self.scheduler_wait_ms - queued = self.num_queued_requests >= self.micro_batch_size - return waited or queued - - def admit_requests(self) -> None: + def admit_requests(self): """Move requests from wait queue into running (inflight) set, up to capacity. Pushes admitted requests directly into the running set. 
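With `should_dispatch` and the explicit wait window removed, the executor loop above reduces to admit, form, process, repeat, relying on `form_batch` returning an empty list when nothing is ready. A condensed sketch of that control flow; `serve_loop` and `_StubScheduler` are stand-ins rather than the executor's real members, and the idle sleep is a sketch-level guard for this toy driver.

# Sketch of the simplified serving loop after removing the explicit dispatch wait.
import time


def serve_loop(scheduler, process_batch, idle_sleep_s=0.01, max_iters=None):
    iters = 0
    while max_iters is None or iters < max_iters:
        iters += 1
        scheduler.admit_requests()          # Phase 1: wait queue -> running set
        batch = scheduler.form_batch()      # Phase 2: running set -> active batch
        if not batch:
            time.sleep(idle_sleep_s)        # nothing ready; avoid busy spinning here
            continue
        process_batch(batch)                # model forward for this micro-batch


class _StubScheduler:
    """Tiny stand-in that serves pre-baked batches; not the real Scheduler."""

    def __init__(self, batches):
        self._batches = list(batches)

    def admit_requests(self):
        pass

    def form_batch(self):
        return self._batches.pop(0) if self._batches else []


serve_loop(_StubScheduler([["p1"], [], ["d1"]]), process_batch=print, max_iters=3)
# prints ['p1'], skips one idle pass, then prints ['d1']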
@@ -222,6 +211,9 @@ def admit_requests(self) -> None: ) continue self._running_requests[rid] = req + logger.debug( + f"Admitted to running: rid={rid}, status={req.status}, running_size={len(self._running_requests)}, ready={req.ready_for_next_step}" + ) # Reflect current running requests metric after admission try: @@ -232,8 +224,7 @@ def admit_requests(self) -> None: except Exception: pass - self._last_dispatch_ts = time.time() - return None + return def form_batch(self) -> List[Request]: """Form the active batch for the next forward pass. @@ -243,9 +234,6 @@ def form_batch(self) -> List[Request]: moved-to-end upon readiness, while respecting micro_batch_size and max_num_tokens_per_batch. """ - # TODO: we need to fully decouple admit_requests and form_batch - # to overlap micro-batch scheduling with both model running & communication to other peers. - self.admit_requests() if not self._running_requests: return [] @@ -253,17 +241,14 @@ def form_batch(self) -> List[Request]: batch: List[Request] = [] # Prefill candidates: preserve admission order via OrderedDict iteration - prefill_candidates: List[Request] = [ - req for req in self._running_requests.values() if req.is_prefill - ] - - # Decode candidates: only those ready, maintain OrderedDict order which was - # updated upon readiness (earlier-ready decodes appear earlier) - decode_ready_candidates: List[Request] = [ - req - for req in self._running_requests.values() - if req.is_decoding and req.ready_for_next_step - ] + prefill_candidates = [] + decode_candidates = [] + for req in self._running_requests.values(): + if req.ready_for_next_step: + if req.is_prefill: + prefill_candidates.append(req) + elif req.is_decoding: + decode_candidates.append(req) # 1) Fill with prefills first for req in prefill_candidates: @@ -276,7 +261,7 @@ def form_batch(self) -> List[Request]: inflight_tokens += cost # 2) Fill remaining with ready decodes - for req in decode_ready_candidates: + for req in decode_candidates: if len(batch) >= self.micro_batch_size: break cost = 1 @@ -289,4 +274,10 @@ def form_batch(self) -> List[Request]: for r in batch: r.ready_for_next_step = False + if batch: + logger.debug( + "Form batch selected=%s inflight_tokens=%d", + [f"{r.request_id}:{r.status}, ready:{r.ready_for_next_step}" for r in batch], + inflight_tokens, + ) return batch From 8e8f1740edbb73670527ad84ddcc3a2ff9abe1eb Mon Sep 17 00:00:00 2001 From: christ-tt Date: Thu, 30 Oct 2025 12:14:48 +0800 Subject: [PATCH 07/11] feat: request time out in batch scheduling --- src/parallax/server/executor.py | 41 ++++++++++++++++++++++++++++++ src/parallax/server/request.py | 1 + src/parallax/server/scheduler.py | 24 +++++++++++++++++ src/parallax/server/server_args.py | 10 ++++++++ 4 files changed, 76 insertions(+) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index fa6aaf0e..49a25534 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -80,6 +80,7 @@ def __init__( prefill_priority: int = 0, micro_batch_ratio: int = 2, scheduler_wait_ms: int = 500, + request_timeout_s: Optional[int] = 600, # Metrics Configs layer_latency_update_every: int = 4096, # KV Cache Configs @@ -243,6 +244,7 @@ def __init__( is_first_peer=self.is_first_peer, tokenizer=self.tokenizer, kv_cache_manager=self.kv_cache_manager if self.device == "mlx" else None, + request_timeout_s=request_timeout_s, ) logger.debug( f"Scheduler initialized (max_batch_size={max_batch_size}, max_tokens={max_num_tokens_per_batch}, wait_ms={scheduler_wait_ms})" @@ 
-1138,6 +1140,45 @@ def run_loop(self): # 4. Admit requests into running set up to capacity, then form batch self.scheduler.admit_requests() + # 4.1 Check for request timeouts and abort timed out requests + try: + timed_out_reqs = self.scheduler.get_timed_out_requests() + if timed_out_reqs: + for req in timed_out_reqs: + rid = req.request_id + logger.warning( + f"Request {rid} exceeded timeout ({req.timeout_s}s). Aborting and releasing resources." + ) + # Release resources + if self.device == "cuda": + from parallax.sglang.batch_info import release_cuda_request + + try: + release_cuda_request(self.running_batch, rid) + except Exception: + pass + else: + try: + if ( + hasattr(self, "kv_cache_manager") + and self.kv_cache_manager is not None + ): + self.kv_cache_manager.release_request(rid) + except Exception: + pass + + # Evict from scheduler + try: + self.scheduler.evict_request(rid) + except Exception: + pass + + # Notify downstream peers to abort if this peer is the first peer in a pipeline + if self.is_first_peer and not self.is_last_peer: + self.finished_batch.append(req) + except Exception: + # Non-fatal; continue serving + pass batch_to_process = self.scheduler.form_batch() if not batch_to_process: continue diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index 7a073262..a34ac1b4 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -104,6 +104,7 @@ def __init__( self.sampling_params = sampling_params or SamplingParams() self.abort = False self.ready_for_next_step = False + self.start_time: Optional[float] = None @property def is_finished(self) -> bool: diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 042761fd..3929c6c4 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -46,6 +46,7 @@ def __init__( micro_batch_ratio: int = 2, is_first_peer: bool = False, kv_cache_manager: Optional[KVCacheManager] = None, + request_timeout_s: Optional[int] = 600, **kwargs, ): """ @@ -56,6 +57,7 @@ def __init__( micro_batch_ratio: micro_batch_size = max_batch_size // micro_batch_ratio; tokenizer: The tokenizer to use for the model; kv_cache_manager: The KV cache manager to use for the scheduler. + request_timeout_s: timeout for each inflight request (default 10mins). """ self.max_batch_size = max_batch_size self.max_num_tokens_per_batch = max_num_tokens_per_batch @@ -75,6 +77,8 @@ def __init__( self._running_requests: Dict[str, Request] = OrderedDict() self.kv_cache_manager = kv_cache_manager + # Default timeout for requests if not set on request object + self.request_timeout_s = request_timeout_s self._last_dispatch_ts = time.time() # Track last reported running requests to avoid redundant metric updates @@ -211,6 +215,8 @@ def admit_requests(self): ) continue self._running_requests[rid] = req + # Initialize timing for timeout enforcement + req.start_time = time.time() logger.debug( f"Admitted to running: rid={rid}, status={req.status}, running_size={len(self._running_requests)}, ready={req.ready_for_next_step}" ) @@ -226,6 +232,24 @@ def admit_requests(self): return + def get_timed_out_requests(self) -> List[Request]: + """Return running requests that exceeded their timeout and mark them aborted. + + This does not evict or release resources; callers must handle cleanup. 
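The timeout path above is deliberately split: the scheduler only detects and marks expired requests (the helper whose docstring ends this hunk, with its body in the next one), while the executor owns cleanup, releasing KV or CUDA resources and evicting. A compact sketch of that detect-then-cleanup split with stand-in request records; the helper names and the dict layout are illustrative, and the timeout value comes from the `request_timeout_s` constructor argument introduced here.

# Sketch: detection (scheduler side) is separated from cleanup (executor side).
import time
from typing import Dict, List


def get_timed_out(running: Dict[str, dict], request_timeout_s: float) -> List[dict]:
    # Mark expired requests as aborted; do NOT evict or free anything here.
    now = time.time()
    timed_out = []
    for req in running.values():
        if now - req["last_seen"] > request_timeout_s:
            req["abort"] = True
            timed_out.append(req)
    return timed_out


def cleanup(running: Dict[str, dict], expired: List[dict], release_kv) -> None:
    # Executor-side: release per-request resources, then drop scheduler bookkeeping.
    for req in expired:
        release_kv(req["rid"])
        running.pop(req["rid"], None)


running = {
    "old": {"rid": "old", "last_seen": time.time() - 700, "abort": False},
    "new": {"rid": "new", "last_seen": time.time(), "abort": False},
}
expired = get_timed_out(running, request_timeout_s=600)
cleanup(running, expired, release_kv=lambda rid: print("released", rid))
print(sorted(running))   # -> ['new']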
+ """ + timed_out: List[Request] = [] + now = time.time() + for req in list(self._running_requests.values()): + try: + if req.start_time is None: + raise ValueError("Requests should have start time set.") + if now - req.start_time > self.timeout_s: + req.abort = True + timed_out.append(req) + except Exception: + continue + return timed_out + def form_batch(self) -> List[Request]: """Form the active batch for the next forward pass. diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 5b1429e0..a0165282 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -126,6 +126,13 @@ def parse_args() -> argparse.Namespace: "--scheduler-wait-ms", type=int, default=500, help="Scheduler wait time in milliseconds" ) + parser.add_argument( + "--request-timeout-s", + type=int, + default=300, + help="Per-request timeout in seconds before automatic abort", + ) + # GPU/SGLang specialized configuration parser.add_argument( "--attention-backend", @@ -213,6 +220,9 @@ def validate_args(args: argparse.Namespace) -> None: if args.scheduler_wait_ms < 0: raise ValueError("scheduler_wait_ms must be non-negative") + if getattr(args, "request_timeout_s", 300) is not None and args.request_timeout_s <= 0: + raise ValueError("request_timeout_s must be positive") + # Validate supported dtypes dtype_list = [ "float16", From 53057674cc14e47a8575860946ee43f2e0119605 Mon Sep 17 00:00:00 2001 From: christ-tt Date: Fri, 31 Oct 2025 04:02:20 +0800 Subject: [PATCH 08/11] fix: bind admit and form batch --- src/parallax/server/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 042761fd..8bd386b6 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -234,6 +234,7 @@ def form_batch(self) -> List[Request]: moved-to-end upon readiness, while respecting micro_batch_size and max_num_tokens_per_batch. 
""" + self.admit_requests() if not self._running_requests: return [] From 9bfe135bcf9d0f7c51e2442e696060acec8a38dd Mon Sep 17 00:00:00 2001 From: christ-tt Date: Fri, 31 Oct 2025 04:10:38 +0800 Subject: [PATCH 09/11] fix: server args --- src/parallax/server/server_args.py | 4 ++-- tests/test_executor.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index a0165282..9686bf9f 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -129,7 +129,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--request-timeout-s", type=int, - default=300, + default=600, help="Per-request timeout in seconds before automatic abort", ) @@ -220,7 +220,7 @@ def validate_args(args: argparse.Namespace) -> None: if args.scheduler_wait_ms < 0: raise ValueError("scheduler_wait_ms must be non-negative") - if getattr(args, "request_timeout_s", 300) is not None and args.request_timeout_s <= 0: + if getattr(args, "request_timeout_s", None) is not None and args.request_timeout_s <= 0: raise ValueError("request_timeout_s must be positive") # Validate supported dtypes diff --git a/tests/test_executor.py b/tests/test_executor.py index 1dd16093..bbb117b4 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -48,6 +48,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps ] executor_peer1._handle_input_requests(initial_requests) + executor_peer1.scheduler.admit_requests() prefill_batch_p1 = executor_peer1.scheduler.form_batch() prefill_inputs_p1 = executor_peer1._prepare_batch_inputs(prefill_batch_p1) assert prefill_inputs_p1 is not None, "Failed to prepare batch inputs" @@ -62,6 +63,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps # send to next peer executor_peer2._handle_input_requests(prefill_reqs_p2) + executor_peer2.scheduler.admit_requests() prefill_batch_p2 = executor_peer2.scheduler.form_batch() prefill_inputs_p2 = executor_peer2._prepare_batch_inputs(prefill_batch_p2) assert prefill_inputs_p2 is not None, "Failed to prepare batch inputs" @@ -82,6 +84,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps executor_peer1._handle_input_requests(feedback_reqs) # 3. Peer 1: form and process decode batch + executor_peer1.scheduler.admit_requests() decode_batch_p1 = executor_peer1.scheduler.form_batch() decode_inputs_p1 = executor_peer1._prepare_batch_inputs(decode_batch_p1) assert decode_inputs_p1 is not None, "Failed to prepare batch inputs" @@ -99,6 +102,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps # 5. 
Peer 2: process decode batch to get next tokens executor_peer2._handle_input_requests(decode_reqs_p2) + executor_peer2.scheduler.admit_requests() decode_batch_p2 = executor_peer2.scheduler.form_batch() decode_inputs_p2 = executor_peer2._prepare_batch_inputs(decode_batch_p2) assert decode_inputs_p2 is not None, "Failed to prepare batch inputs" From 0655763040836447d05c141a4ef647d616cf4c43 Mon Sep 17 00:00:00 2001 From: christ-tt Date: Fri, 31 Oct 2025 04:17:32 +0800 Subject: [PATCH 10/11] fix: modular release & evict --- src/parallax/server/executor.py | 55 +++++++++++++++------------------ 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 49a25534..0819ea6d 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -1116,6 +1116,29 @@ def process_batch( ) return ret + def _release_and_evict_request(self, rid: str): + """Release per-request resources and evict from scheduler. Best-effort, never raises.""" + # Release resources + if self.device == "cuda": + from parallax.sglang.batch_info import release_cuda_request + + try: + release_cuda_request(self.running_batch, rid) + except Exception: + pass + else: + try: + if hasattr(self, "kv_cache_manager") and self.kv_cache_manager is not None: + self.kv_cache_manager.release_request(rid) + except Exception: + pass + + # Evict from scheduler + try: + self.scheduler.evict_request(rid) + except Exception: + pass + def run_loop(self): """The main loop of the executor.""" logger.debug( @@ -1149,29 +1172,7 @@ def run_loop(self): logger.warning( f"Request {rid} exceeded timeout ({req.timeout_s}s). Aborting and releasing resources." ) - # Release resources - if self.device == "cuda": - from parallax.sglang.batch_info import release_cuda_request - - try: - release_cuda_request(self.running_batch, rid) - except Exception: - pass - else: - try: - if ( - hasattr(self, "kv_cache_manager") - and self.kv_cache_manager is not None - ): - self.kv_cache_manager.release_request(rid) - except Exception: - pass - - # Evict from scheduler - try: - self.scheduler.evict_request(rid) - except Exception: - pass + self._release_and_evict_request(rid) # Notify downstream peers to abort if this peer is the first peer in a pipeline if self.is_first_peer and not self.is_last_peer: @@ -1240,13 +1241,7 @@ def run_loop(self): logger.exception(f"Error processing batch: {e}") # Naive error handling: release and evict all requests in the batch for req in batch_to_process: - self.scheduler.evict_request(req.request_id) - if self.device == "cuda": - from parallax.sglang.batch_info import release_cuda_request - - release_cuda_request(self.running_batch, req.request_id) - else: - self.kv_cache_manager.release_request(req.request_id) + self._release_and_evict_request(req.request_id) def run_loop_in_background(self): """Run the executor loop in the background.""" From 6835f9346370c7613c71d19b5ac2b9b03124cd36 Mon Sep 17 00:00:00 2001 From: christ-tt Date: Fri, 31 Oct 2025 12:23:07 +0800 Subject: [PATCH 11/11] fix: request start time to last updated time --- src/parallax/server/request.py | 2 +- src/parallax/server/scheduler.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index a34ac1b4..a628c1a2 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -104,7 +104,7 @@ def __init__( self.sampling_params = sampling_params or SamplingParams() self.abort = 
False self.ready_for_next_step = False - self.start_time: Optional[float] = None + self.last_updated_time: Optional[float] = None @property def is_finished(self) -> bool: diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 83006dae..c1e28c85 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -123,6 +123,7 @@ def enque_request(self, request: Request | str): return request.ready_for_next_step = True + request.last_updated_time = time.time() # TODO: Handle chunked prefill. if request.is_decoding: rid = request.request_id @@ -216,7 +217,7 @@ def admit_requests(self): continue self._running_requests[rid] = req # Initialize timing for timeout enforcement - req.start_time = time.time() + req.last_updated_time = time.time() logger.debug( f"Admitted to running: rid={rid}, status={req.status}, running_size={len(self._running_requests)}, ready={req.ready_for_next_step}" ) @@ -241,9 +242,9 @@ def get_timed_out_requests(self) -> List[Request]: now = time.time() for req in list(self._running_requests.values()): try: - if req.start_time is None: - raise ValueError("Requests should have start time set.") - if now - req.start_time > self.timeout_s: + if req.last_updated_time is None: + raise ValueError("Requests should have last updated time set.") + if now - req.last_updated_time > self.timeout_s: req.abort = True timed_out.append(req) except Exception: @@ -298,6 +299,7 @@ def form_batch(self) -> List[Request]: # Clear ready flags for decodes included in this batch for r in batch: r.ready_for_next_step = False + r.last_updated_time = time.time() if batch: logger.debug(
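The final patch above re-bases the timeout on `last_updated_time`, refreshed when a request is enqueued, when it is admitted, and each time it is placed in a batch. A long but steadily progressing generation therefore keeps running; only a request that makes no progress for `request_timeout_s` seconds is marked for abort. A small sketch of that inactivity semantics, with a stand-in record and a compressed timescale:

# Sketch: timeout on inactivity (last_updated_time) rather than total request age.
import time

TIMEOUT_S = 0.2          # compressed stand-in for request_timeout_s


def is_timed_out(req, now=None):
    now = time.time() if now is None else now
    return now - req["last_updated_time"] > TIMEOUT_S


req = {"rid": "r1", "last_updated_time": time.time()}

# Simulate three decode steps, each refreshing the activity timestamp on batch inclusion.
for _ in range(3):
    time.sleep(0.1)                         # total age grows well past TIMEOUT_S ...
    req["last_updated_time"] = time.time()  # ... but each step counts as progress

print(is_timed_out(req))                    # -> False: still making progress

time.sleep(0.3)                             # no progress for longer than the timeout
print(is_timed_out(req))                    # -> True: would now be marked for abort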