diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index 4b6b6d6c..0819ea6d 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -80,6 +80,7 @@ def __init__(
         prefill_priority: int = 0,
         micro_batch_ratio: int = 2,
         scheduler_wait_ms: int = 500,
+        request_timeout_s: Optional[int] = 600,
         # Metrics Configs
         layer_latency_update_every: int = 4096,
         # KV Cache Configs
@@ -242,6 +243,8 @@ def __init__(
             micro_batch_ratio=micro_batch_ratio,
             is_first_peer=self.is_first_peer,
             tokenizer=self.tokenizer,
+            kv_cache_manager=self.kv_cache_manager if self.device == "mlx" else None,
+            request_timeout_s=request_timeout_s,
         )
         logger.debug(
             f"Scheduler initialized (max_batch_size={max_batch_size}, max_tokens={max_num_tokens_per_batch}, wait_ms={scheduler_wait_ms})"
         )
@@ -781,7 +784,7 @@ def _handle_cuda_input_requests(self, requests: List[Request]):
                 req, IntermediateRequest
             ), "Non-first peers must receive IntermediateRequests."
             if req.is_finished or req.hidden_states is None:
-                self.scheduler.evict_request(req.request_id, req.status)
+                self.scheduler.evict_request(req.request_id)
                 release_cuda_request(self.running_batch, req.request_id)
                 if not self.is_last_peer:
                     self.finished_batch.append(req)
@@ -805,8 +808,6 @@ def _handle_input_requests(self, requests: List[Request]):
         # or IntermediateRequests from the last peer.
         for req in requests:
             if isinstance(req, InitialRequest):
-                if not self.kv_cache_manager.has_request(req.request_id):
-                    self.kv_cache_manager.add_request(req, req.total_length)
                 self.scheduler.enque_request(req)
             elif isinstance(req, IntermediateRequest):
                 original_req = self.scheduler.get_running_request(req.request_id)
@@ -876,14 +877,12 @@ def _handle_input_requests(self, requests: List[Request]):
                        f"kv cache manager has {self.kv_cache_manager.tokens_in_cache} tokens, "
                        f"memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB"
                    )
-                    self.scheduler.evict_request(req.request_id, req.status)
+                    self.scheduler.evict_request(req.request_id)
                     if not self.is_last_peer:
                         self.finished_batch.append(req)
                 else:
                     # This is an active request, add it to the scheduler queue to be processed.
                     self.scheduler.enque_request(req)
-                    if not self.kv_cache_manager.has_request(req.request_id):
-                        self.kv_cache_manager.add_request(req, req.total_length)

     def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Request:
         """Handle request state changes both inter and intra peers.
@@ -1058,7 +1057,7 @@ def _process_batch_mlx(
        ]  # k_caches shape: (num_layers, B, num_kv_heads, L_padded, head_dim)
        logger.debug(
-            f"Processed batch with {len(prepared_inputs['requests'])} requests, "
+            f"Processing batch with {len(prepared_inputs['requests'])} requests, "
            f"request status: {prepared_inputs['requests'][0].status}, "
            f"hidden_states shape: {hidden_states.shape}, "
            f"k_caches shape: {k_caches.shape}, "
@@ -1117,6 +1116,29 @@ def process_batch(
         )
         return ret

+    def _release_and_evict_request(self, rid: str):
+        """Release per-request resources and evict from scheduler. Best-effort, never raises."""
Best-effort, never raises.""" + # Release resources + if self.device == "cuda": + from parallax.sglang.batch_info import release_cuda_request + + try: + release_cuda_request(self.running_batch, rid) + except Exception: + pass + else: + try: + if hasattr(self, "kv_cache_manager") and self.kv_cache_manager is not None: + self.kv_cache_manager.release_request(rid) + except Exception: + pass + + # Evict from scheduler + try: + self.scheduler.evict_request(rid) + except Exception: + pass + def run_loop(self): """The main loop of the executor.""" logger.debug( @@ -1139,12 +1161,25 @@ def run_loop(self): ) self.finished_batch = [] - # 4. Check if we should form a batch - if not self.scheduler.should_dispatch(): - time.sleep(0.01) # prevent busy waiting - continue + # 4. Admit requests into running set up to capacity, then form batch + self.scheduler.admit_requests() + # 4.1 Check for request timeouts and abort timed out requests + try: + timed_out_reqs = self.scheduler.get_timed_out_requests() + if timed_out_reqs: + for req in timed_out_reqs: + rid = req.request_id + logger.warning( + f"Request {rid} exceeded timeout ({req.timeout_s}s). Aborting and releasing resources." + ) + self._release_and_evict_request(rid) - # 5. Form a batch from the scheduler's queue + # Notify downstream peers to abort if this peer is the first peer in a pipeline + if self.is_first_peer and not self.is_last_peer: + self.finished_batch.append(req) + except Exception: + # Non-fatal; continue serving + pass batch_to_process = self.scheduler.form_batch() if not batch_to_process: continue @@ -1206,13 +1241,7 @@ def run_loop(self): logger.exception(f"Error processing batch: {e}") # Naive error handling: release and evict all requests in the batch for req in batch_to_process: - self.scheduler.evict_request(req.request_id, req.status) - if self.device == "cuda": - from parallax.sglang.batch_info import release_cuda_request - - release_cuda_request(self.running_batch, req.request_id) - else: - self.kv_cache_manager.release_request(req.request_id) + self._release_and_evict_request(req.request_id) def run_loop_in_background(self): """Run the executor loop in the background.""" diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index f31b4959..a628c1a2 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -103,6 +103,8 @@ def __init__( self.routing_table = routing_table self.sampling_params = sampling_params or SamplingParams() self.abort = False + self.ready_for_next_step = False + self.last_updated_time: Optional[float] = None @property def is_finished(self) -> bool: diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index fd500b81..c1e28c85 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -1,15 +1,29 @@ """ -Scheduling requests to form batches.sche - -A scheduler will maintain a Priority Queue for request waiting pool. -We support continuous batching, and similar to TensorRT-LLM, - we favors prefill requests over decode requests. +Continuous Batching Scheduler. + +State managed by the scheduler: + 1. Prefill Wait Queue (FIFO): incoming prefill requests waiting for admission; + 2. Running Requests: inflight requests with KV-cache residency; +Main `form_batch` function will return the concrete batch chosen for the next model forward. + +We use an explicit 2-Phase approach: + * Phase 1 (Admission): wait queue -> running requests + Implemented by `admit_requests`. 
+      allows (e.g., max concurrent requests, memory availability). Admitted
+      requests get KV-cache residency and become inflight.
+    * Phase 2 (Batching): running requests -> active batch for the actual forward pass.
+      Implemented by `form_batch`. We prioritize PREFILL requests
+      first within `max_num_tokens_per_batch` and `micro_batch_size`,
+      then include DECODE requests that are marked ready for the next decode step.
+
+The scheduler also handles tokenization and pre-processing for the First Peer's requests.
 """

-import heapq
 import time
-from typing import Dict, List, Literal, Optional, Tuple
+from collections import OrderedDict
+from typing import Dict, List, Optional

+from parallax.server.kv_cache import KVCacheManager
 from parallax.server.metrics import update_metrics
 from parallax.server.request import InitialRequest, Request, RequestStatus
 from parallax_utils.logging_config import get_logger
@@ -19,29 +33,31 @@ class Scheduler:
     """
-    A simple scheduler to manage requests and form them into batches.
-    This scheduler is designed to handle requests in a FIFO manner.
+    2-Phase approach:
+      * Phase 1: wait queue -> running requests (all inflight requests)
+      * Phase 2: running requests -> active batch (actual model forward)
     """

     def __init__(
         self,
         max_batch_size: int = 16,
         max_num_tokens_per_batch: int = 4096,
-        prefill_priority: Literal[0, 1] = 0,
         scheduler_wait_ms: int = 200,
         micro_batch_ratio: int = 2,
         is_first_peer: bool = False,
+        kv_cache_manager: Optional[KVCacheManager] = None,
+        request_timeout_s: Optional[int] = 600,
         **kwargs,
     ):
         """
         Args:
-            max_batch_size: Maximum number of running requests;
+            max_batch_size: Maximum number of running / inflight requests;
             max_num_tokens_per_batch: Maxmimum number of prefill + decode tokens in a single batch;
-            prefill_priority: Priority for prefill requests,
-            default 0 for prefill, 1 for decode, 0 for higher priority;
             scheduler_wait_ms: The minimum time to wait before dispatching a batch;
-            micro_batch_ratio: micro_batch_size = max_batch_size // micro_batch_ratio
-            tokenizer: The tokenizer to use for the model.
+            micro_batch_ratio: micro_batch_size = max_batch_size // micro_batch_ratio;
+            tokenizer: The tokenizer to use for the model;
+            kv_cache_manager: The KV cache manager to use for the scheduler;
+            request_timeout_s: Timeout in seconds for each inflight request (default 600, i.e. 10 minutes).
""" self.max_batch_size = max_batch_size self.max_num_tokens_per_batch = max_num_tokens_per_batch @@ -55,15 +71,15 @@ def __init__( self.max_new_tokens = kwargs.get("max_new_tokens", 512) self.max_total_length = kwargs.get("max_total_length", 1024) - # Priority queue: (priority, arrival_time, request_id, request_object) - self._request_queue: List[Tuple[int, float, str, Request]] = [] + # Prefill wait queue (FIFO) for admission + self._wait_queue: List[Request] = [] # Keeps track of all in-flight requests - self._running_requests: Dict[str, Request] = {} + self._running_requests: Dict[str, Request] = OrderedDict() + + self.kv_cache_manager = kv_cache_manager + # Default timeout for requests if not set on request object + self.request_timeout_s = request_timeout_s - self.priority_map = { - RequestStatus.PREFILLING: prefill_priority, - RequestStatus.DECODING: 1 - prefill_priority, - } self._last_dispatch_ts = time.time() # Track last reported running requests to avoid redundant metric updates self._last_reported_running_requests: int = 0 @@ -75,18 +91,13 @@ def __init__( @property def num_queued_requests(self) -> int: """Get the number of requests in the scheduler.""" - return len(self._request_queue) + return len(self._wait_queue) @property def num_running_requests(self) -> int: """Get the number of requests currently being processed.""" return len(self._running_requests) - @property - def has_pending_requests(self) -> bool: - """Check if there are any pending requests in the scheduler.""" - return len(self._request_queue) > 0 - def get_running_request(self, request_id: str) -> Optional[Request]: """Gets a request that is currently in the running state.""" return self._running_requests.get(request_id) @@ -100,7 +111,7 @@ def _prompt_string_to_request(self, request_str: str) -> InitialRequest: ) def enque_request(self, request: Request | str): - """Add a request to the scheduler.""" + """Enque a request to the scheduler's wait queue.""" if isinstance(request, str): request = self._prompt_string_to_request(request) @@ -110,25 +121,37 @@ def enque_request(self, request: Request | str): f"{request.status}. Not adding to the scheduler." ) return - arrival_time = time.time() - priority = self.priority_map.get(request.status, 1) - heapq.heappush(self._request_queue, (priority, arrival_time, request.request_id, request)) - logger.debug(f"Request {request.request_id} added to the scheduler.") - logger.debug(f"Scheduler queue size: {len(self._request_queue)}") - # Running count does not change on enqueue; do not update metrics here - - def evict_request(self, request_id: str, status: Optional[RequestStatus] = None): + + request.ready_for_next_step = True + request.last_updated_time = time.time() + # TODO: Handle chunked prefill. + if request.is_decoding: + rid = request.request_id + if rid not in self._running_requests: + raise ValueError( + f"Decode request {rid} must already be admitted (in running requests)." + ) + # Merge incoming decode readiness/state into the existing running request + self._running_requests[rid] = request + # Update recency ordering so earlier-ready decodes are encountered first during batching + self._running_requests.move_to_end(rid) + logger.debug(f"Decode request {rid} marked ready for next decode.") + return + + self._wait_queue.append(request) + logger.debug( + f"Prefill request {request.request_id} added to the prefill wait queue (size={len(self._wait_queue)})." 
+        )
+
+    def evict_request(self, request_id: str):
         """Removes a request from the scheduler's running queue."""
-        _ = status  # status is used by the first peer's logic but not here.
         if request_id in self._running_requests:
             self._running_requests.pop(request_id)
             logger.debug(f"Evicted request {request_id} from scheduler.")
             # Update metrics only if running count changed since last report
             try:
                 curr = self.num_running_requests
-                # if curr != self._last_reported_running_requests:
                 update_metrics(current_requests=curr)
-                # self._last_reported_running_requests = curr
             except Exception:
                 pass
         else:
@@ -173,54 +196,115 @@ def check_and_update_request_status(self, request: InitialRequest) -> bool:

         return finished

-    def should_dispatch(self) -> bool:
-        """Helper check if the scheduler should dispatch a batch."""
-        waited = (time.time() - self._last_dispatch_ts) * 1000 >= self.scheduler_wait_ms
-        queued = self.num_queued_requests >= self.micro_batch_size
-        return waited or queued
+    def admit_requests(self):
+        """Move requests from the wait queue into the running (inflight) set, up to capacity.
+
+        Admission is capped by `max_batch_size` and, when a KV cache manager is
+        configured, by available KV cache capacity.
+        """
+        while self._wait_queue and len(self._running_requests) < self.max_batch_size:
+            req = self._wait_queue.pop(0)
+            rid = req.request_id
+            if rid in self._running_requests:
+                # Already inflight (e.g., chunked prefill); skip
+                continue
+            # Check the KV cache pool before granting residency
+            if self.kv_cache_manager is not None:
+                if not self.kv_cache_manager.has_request(req.request_id):
+                    if not self.kv_cache_manager.add_request(req, req.total_length):
+                        logger.warning(
+                            f"Request {rid} cannot be admitted to the running set yet: insufficient KV cache capacity."
+                        )
+                        # Re-queue at the head and stop admitting; later requests face the same pressure.
+                        self._wait_queue.insert(0, req)
+                        break
+            self._running_requests[rid] = req
+            # Initialize timing for timeout enforcement
+            req.last_updated_time = time.time()
+            logger.debug(
+                f"Admitted to running: rid={rid}, status={req.status}, running_size={len(self._running_requests)}, ready={req.ready_for_next_step}"
+            )
+
+        # Reflect current running requests metric after admission
+        try:
+            curr = self.num_running_requests
+            if curr != self._last_reported_running_requests:
+                update_metrics(current_requests=curr)
+                self._last_reported_running_requests = curr
+        except Exception:
+            pass
+
+    def get_timed_out_requests(self) -> List[Request]:
+        """Return running requests that exceeded their timeout and mark them aborted.
+
+        This does not evict or release resources; callers must handle cleanup.
+        """
+        timed_out: List[Request] = []
+        if self.request_timeout_s is None:
+            return timed_out
+        now = time.time()
+        for req in list(self._running_requests.values()):
+            try:
+                if req.last_updated_time is None:
+                    raise ValueError("Requests should have last updated time set.")
+                if now - req.last_updated_time > self.request_timeout_s:
+                    req.abort = True
+                    timed_out.append(req)
+            except Exception:
+                continue
+        return timed_out

     def form_batch(self) -> List[Request]:
-        """Get the next batch of requests.
+        """Form the active batch for the next forward pass.

-        At-most `micro_batch_size` requests will be returned.
+        Select prefills first (FIFO by admission order), then decodes that are
+        marked ready, following the OrderedDict iteration order (ready decodes
+        are moved to the end when they become ready), while respecting
+        `micro_batch_size` and `max_num_tokens_per_batch`.
""" - if not self.has_pending_requests: + self.admit_requests() + if not self._running_requests: return [] inflight_tokens = 0 - - batch = [] - save_index = [] - for index, request in enumerate(self._request_queue): + batch: List[Request] = [] + + # Prefill candidates: preserve admission order via OrderedDict iteration + prefill_candidates = [] + decode_candidates = [] + for req in self._running_requests.values(): + if req.ready_for_next_step: + if req.is_prefill: + prefill_candidates.append(req) + elif req.is_decoding: + decode_candidates.append(req) + + # 1) Fill with prefills first + for req in prefill_candidates: if len(batch) >= self.micro_batch_size: - save_index.append(index) + break + cost = req.prompt_len + if cost + inflight_tokens > self.max_num_tokens_per_batch: continue - _, _, rid, req = request + batch.append(req) + inflight_tokens += cost - cost = req.prompt_len if req.is_prefill else 1 + # 2) Fill remaining with ready decodes + for req in decode_candidates: + if len(batch) >= self.micro_batch_size: + break + cost = 1 if cost + inflight_tokens > self.max_num_tokens_per_batch: - save_index.append(index) continue - - if rid not in self._running_requests: - if len(self._running_requests) >= self.max_batch_size: - save_index.append(index) - continue - batch.append(req) - self._running_requests[rid] = req - inflight_tokens += cost - self._request_queue = [self._request_queue[i] for i in save_index] - - # Reflect current running requests metric after forming the batch - try: - curr = self.num_running_requests - if curr != self._last_reported_running_requests: - update_metrics(current_requests=curr) - self._last_reported_running_requests = curr - except Exception: - pass + # Clear ready flags for decodes included in this batch + for r in batch: + r.ready_for_next_step = False + r.last_updated_time = time.time() + if batch: + logger.debug( + "Form batch selected=%s inflight_tokens=%d", + [f"{r.request_id}:{r.status}, ready:{r.ready_for_next_step}" for r in batch], + inflight_tokens, + ) return batch diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 5b1429e0..9686bf9f 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -126,6 +126,13 @@ def parse_args() -> argparse.Namespace: "--scheduler-wait-ms", type=int, default=500, help="Scheduler wait time in milliseconds" ) + parser.add_argument( + "--request-timeout-s", + type=int, + default=600, + help="Per-request timeout in seconds before automatic abort", + ) + # GPU/SGLang specialized configuration parser.add_argument( "--attention-backend", @@ -213,6 +220,9 @@ def validate_args(args: argparse.Namespace) -> None: if args.scheduler_wait_ms < 0: raise ValueError("scheduler_wait_ms must be non-negative") + if getattr(args, "request_timeout_s", None) is not None and args.request_timeout_s <= 0: + raise ValueError("request_timeout_s must be positive") + # Validate supported dtypes dtype_list = [ "float16", diff --git a/tests/test_batch_scheduler.py b/tests/test_batch_scheduler.py new file mode 100644 index 00000000..80fa7e9c --- /dev/null +++ b/tests/test_batch_scheduler.py @@ -0,0 +1,107 @@ +from parallax.server.request import InitialRequest, Request, RequestStatus +from parallax.server.scheduler import Scheduler + + +class FakeKVCacheManager: + def __init__(self, allow: bool = True): + self.allow = allow + self._reqs = set() + + def has_request(self, request_id: str) -> bool: + return request_id in self._reqs + + def add_request(self, request: Request, 
+        if not self.allow:
+            return False
+        self._reqs.add(request.request_id)
+        return True
+
+
+def make_prefill(rid: str, prompt_len: int) -> InitialRequest:
+    return InitialRequest(request_id=rid, input_ids=[0] * prompt_len)
+
+
+def make_decode(rid: str, ready: bool = True) -> Request:
+    r = Request(request_id=rid, status=RequestStatus.DECODING)
+    r.ready_for_next_step = ready
+    return r
+
+
+def test_prefill_fifo_and_micro_batch():
+    sched = Scheduler(max_batch_size=8, max_num_tokens_per_batch=10_000, micro_batch_ratio=1)
+    # micro_batch_size = max_batch_size // ratio = 8
+    # Enqueue 3 prefills in order
+    r1 = make_prefill("r1", 5)
+    r2 = make_prefill("r2", 6)
+    r3 = make_prefill("r3", 7)
+    sched.enque_request(r1)
+    sched.enque_request(r2)
+    sched.enque_request(r3)
+
+    batch = sched.form_batch()
+    ids = [r.request_id for r in batch]
+    assert ids[:3] == ["r1", "r2", "r3"]
+
+
+def test_decode_ready_order_and_prefill_first():
+    # micro_batch_size = 3
+    sched = Scheduler(max_batch_size=3, max_num_tokens_per_batch=10_000, micro_batch_ratio=1)
+
+    # Two decodes already running
+    d1 = make_decode("d1")
+    d2 = make_decode("d2")
+    sched._running_requests[d1.request_id] = d1
+    sched._running_requests[d2.request_id] = d2
+
+    # One prefill in queue
+    p1 = make_prefill("p1", 8)
+    sched.enque_request(p1)
+
+    # Mark d1 ready first, then d2
+    sched.enque_request(d1)  # sets ready_for_next_step + LRU move_to_end
+    sched.enque_request(d2)
+
+    sched.admit_requests()
+    batch = sched.form_batch()
+    ids = [r.request_id for r in batch]
+
+    # Prefill first, then decodes in the order they became ready
+    assert ids == ["p1", "d1", "d2"]
+
+
+def test_token_budget_prefill_skipped_decode_taken():
+    # Token budget too small for prefill, but enough for decodes (cost=1)
+    sched = Scheduler(max_batch_size=2, max_num_tokens_per_batch=1, micro_batch_ratio=1)
+
+    # One large prefill
+    p_big = make_prefill("p_big", 5)
+    sched.enque_request(p_big)
+
+    # One ready decode already running
+    d = make_decode("d")
+    sched._running_requests[d.request_id] = d
+    sched.enque_request(d)
+
+    batch = sched.form_batch()
+    ids = [r.request_id for r in batch]
+    assert ids == ["d"]
+    # ready flag should be reset after batching
+    assert getattr(d, "ready_for_next_step", False) is False
+
+
+def test_kv_cache_admission_guard_blocks_prefill():
+    # A KV manager that rejects additions
+    kv_mgr = FakeKVCacheManager(allow=False)
+    sched = Scheduler(
+        max_batch_size=2,
+        max_num_tokens_per_batch=100,
+        micro_batch_ratio=1,
+        kv_cache_manager=kv_mgr,
+    )
+    p = make_prefill("p", 4)
+    sched.enque_request(p)
+
+    # Admission should fail and running set remains empty; batch should be empty
+    batch = sched.form_batch()
+    assert len(batch) == 0
+    assert sched.num_running_requests == 0
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 1dd16093..bbb117b4 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -48,6 +48,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps
     ]

     executor_peer1._handle_input_requests(initial_requests)
+    executor_peer1.scheduler.admit_requests()
     prefill_batch_p1 = executor_peer1.scheduler.form_batch()
     prefill_inputs_p1 = executor_peer1._prepare_batch_inputs(prefill_batch_p1)
     assert prefill_inputs_p1 is not None, "Failed to prepare batch inputs"
@@ -62,6 +63,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps

     # send to next peer
     executor_peer2._handle_input_requests(prefill_reqs_p2)
+    executor_peer2.scheduler.admit_requests()
     prefill_batch_p2 = executor_peer2.scheduler.form_batch()
     prefill_inputs_p2 = executor_peer2._prepare_batch_inputs(prefill_batch_p2)
     assert prefill_inputs_p2 is not None, "Failed to prepare batch inputs"
@@ -82,6 +84,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps
         executor_peer1._handle_input_requests(feedback_reqs)

         # 3. Peer 1: form and process decode batch
+        executor_peer1.scheduler.admit_requests()
         decode_batch_p1 = executor_peer1.scheduler.form_batch()
         decode_inputs_p1 = executor_peer1._prepare_batch_inputs(decode_batch_p1)
         assert decode_inputs_p1 is not None, "Failed to prepare batch inputs"
@@ -99,6 +102,7 @@ def test_decode_pipeline_multiple_steps(start_layer, end_layer, num_decode_steps

         # 5. Peer 2: process decode batch to get next tokens
         executor_peer2._handle_input_requests(decode_reqs_p2)
+        executor_peer2.scheduler.admit_requests()
         decode_batch_p2 = executor_peer2.scheduler.form_batch()
         decode_inputs_p2 = executor_peer2._prepare_batch_inputs(decode_batch_p2)
         assert decode_inputs_p2 is not None, "Failed to prepare batch inputs"
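
Reviewer note (not part of the patch): a minimal sketch of how the new two-phase scheduler API is intended to be driven, mirroring the executor's run_loop (enqueue -> admit_requests -> form_batch -> timeout sweep). The request ids, prompt lengths, and passing kv_cache_manager=None are illustrative assumptions, not values taken from the patch.

# Sketch: driving the two-phase scheduler with illustrative values.
from parallax.server.request import InitialRequest
from parallax.server.scheduler import Scheduler

scheduler = Scheduler(
    max_batch_size=4,
    max_num_tokens_per_batch=4096,
    micro_batch_ratio=1,
    kv_cache_manager=None,  # the MLX executor passes its KVCacheManager here
    request_timeout_s=600,
)

# Incoming prefill requests land in the FIFO wait queue.
for rid, prompt_len in [("req-a", 16), ("req-b", 32)]:
    scheduler.enque_request(InitialRequest(request_id=rid, input_ids=[0] * prompt_len))

# Phase 1 (admission): grant residency up to max_batch_size / KV-cache capacity.
scheduler.admit_requests()

# Phase 2 (batching): prefills first, then decodes marked ready_for_next_step.
batch = scheduler.form_batch()
# ... run the model forward on `batch`; when a decode step's output comes back,
# calling enque_request(request) again marks the running request ready.

# Timeout sweep: the scheduler only flags aborted requests; the caller must
# release resources and evict (the executor does this via _release_and_evict_request).
for req in scheduler.get_timed_out_requests():
    scheduler.evict_request(req.request_id)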