diff --git a/src/litserve/server.py b/src/litserve/server.py index 05417a48..41d1db8b 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -26,7 +26,7 @@ import time import os import shutil -from typing import Sequence, Optional, Union, List +from typing import Sequence, Optional, Union, List, Dict import uuid from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks, Request, Response @@ -68,16 +68,28 @@ def get_batch_from_uid(uids, lit_api, request_buffer): return batches -def collate_requests(lit_api, request_queue: Queue, request_buffer, max_batch_size, batch_timeout): +def collate_requests( + lit_api: LitAPI, request_queue: Queue, request_buffer: Dict, max_batch_size: int, batch_timeout: float +) -> Optional[List]: uids = [] entered_at = time.time() - while (batch_timeout - (time.time() - entered_at) > 0) and len(uids) < max_batch_size: + end_time = entered_at + batch_timeout + + while time.time() < end_time and len(uids) < max_batch_size: + remaining_time = end_time - time.time() + if remaining_time <= 0: + break + try: - uid = request_queue.get(timeout=0.001) + uid = request_queue.get(timeout=min(remaining_time, 0.001)) uids.append(uid) - except (Empty, ValueError): + except Empty: continue - return get_batch_from_uid(uids, lit_api, request_buffer) + + if uids: + return get_batch_from_uid(uids, lit_api, request_buffer) + + return None def run_batched_loop(lit_api, lit_spec, request_queue: Queue, request_buffer, max_batch_size, batch_timeout):