diff --git a/inference/server/oasst_inference_server/queueing.py b/inference/server/oasst_inference_server/queueing.py index 535ffc47c1..35b83f01ec 100644 --- a/inference/server/oasst_inference_server/queueing.py +++ b/inference/server/oasst_inference_server/queueing.py @@ -2,6 +2,10 @@ from oasst_inference_server.settings import settings +class QueueFullException(Exception): + pass + + class RedisQueue: def __init__( self, @@ -10,14 +14,19 @@ def __init__( expire: int | None = None, with_counter: bool = False, counter_pos_expire: int = 1, + max_size: int | None = None, ) -> None: self.redis_client = redis_client self.queue_id = queue_id self.expire = expire self.with_counter = with_counter self.counter_pos_expire = counter_pos_expire + self.max_size = max_size or 0 - async def enqueue(self, value: str) -> int | None: + async def enqueue(self, value: str, enforce_max_size: bool = True) -> int | None: + if enforce_max_size and self.max_size > 0: + if await self.get_length() >= self.max_size: + raise QueueFullException() await self.redis_client.rpush(self.queue_id, value) if self.expire is not None: await self.set_expire(self.expire) @@ -71,7 +80,11 @@ def work_queue(redis_client: redis.Redis, worker_compat_hash: str) -> RedisQueue if worker_compat_hash not in settings.allowed_worker_compat_hashes_list: raise ValueError(f"Worker compat hash {worker_compat_hash} not allowed") return RedisQueue( - redis_client, f"work:{worker_compat_hash}", with_counter=True, counter_pos_expire=settings.message_queue_expire + redis_client, + f"work:{worker_compat_hash}", + with_counter=True, + counter_pos_expire=settings.message_queue_expire, + max_size=settings.work_queue_max_size, ) diff --git a/inference/server/oasst_inference_server/routes/chats.py b/inference/server/oasst_inference_server/routes/chats.py index 1f2d05e5dd..e7421662cb 100644 --- a/inference/server/oasst_inference_server/routes/chats.py +++ b/inference/server/oasst_inference_server/routes/chats.py @@ -103,6 +103,11 @@ async def create_assistant_message( await queue.enqueue(assistant_message.id) logger.debug(f"Added {assistant_message.id=} to {queue.queue_id} for {chat_id}") return assistant_message.to_read() + except queueing.QueueFullException: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_503_SERVICE_UNAVAILABLE, + detail="The server is currently busy. Please try again later.", + ) except Exception: logger.exception("Error adding prompter message") return fastapi.Response(status_code=500) diff --git a/inference/server/oasst_inference_server/routes/workers.py b/inference/server/oasst_inference_server/routes/workers.py index ff4955a38f..aea450d9e4 100644 --- a/inference/server/oasst_inference_server/routes/workers.py +++ b/inference/server/oasst_inference_server/routes/workers.py @@ -230,7 +230,7 @@ def _add_receive(ftrs: set): logger.warning(f"Marking {message_id=} as pending since no work was done.") async with deps.manual_chat_repository() as cr: await cr.reset_work(message_id) - await work_queue.enqueue(message_id) + await work_queue.enqueue(message_id, enforce_max_size=False) else: logger.warning(f"Aborting {message_id=}") await abort_message(message_id=message_id, error="Aborted due to worker error.") @@ -312,7 +312,7 @@ async def initiate_work_for_message( logger.exception(f"Error while sending work request to worker: {str(e)}") async with deps.manual_create_session() as session: await cr.reset_work(message_id) - await work_queue.enqueue(message_id) + await work_queue.enqueue(message_id, enforce_max_size=False) raise return work_request diff --git a/inference/server/oasst_inference_server/settings.py b/inference/server/oasst_inference_server/settings.py index 366e0b4a3b..b197519893 100644 --- a/inference/server/oasst_inference_server/settings.py +++ b/inference/server/oasst_inference_server/settings.py @@ -9,6 +9,7 @@ class Settings(pydantic.BaseSettings): redis_db: int = 0 message_queue_expire: int = 60 + work_queue_max_size: int | None = None allowed_worker_compat_hashes: str = "*"