diff --git a/inference/server/main.py b/inference/server/main.py
index dc8c734ab3..34acefcb82 100644
--- a/inference/server/main.py
+++ b/inference/server/main.py
@@ -29,6 +29,11 @@ async def enable_prom_metrics():
         Instrumentator().instrument(app).expose(app)
 
 
+@app.on_event("startup")
+async def log_inference_protocol_version():
+    logger.info(f"Inference protocol version: {inference.INFERENCE_PROTOCOL_VERSION}")
+
+
 # Allow CORS
 app.add_middleware(
     CORSMiddleware,
@@ -70,19 +75,34 @@ def manual_chat_repository():
 api_key_header = fastapi.Header(None, alias="X-API-Key")
 
 
-def get_api_key(api_key_header: str = api_key_header) -> str:
-    if api_key_header is None:
+def get_api_key(api_key: str = api_key_header) -> str:
+    if api_key is None:
         raise fastapi.HTTPException(
             status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
             detail="Missing API key",
         )
-    return api_key_header
+    return api_key
+
+
+protocol_version_header = fastapi.Header(None, alias="X-Protocol-Version")
+
+
+def get_protocol_version(protocol_version: str = protocol_version_header) -> str:
+    if protocol_version != inference.INFERENCE_PROTOCOL_VERSION:
+        logger.warning(f"Got worker with incompatible protocol version: {protocol_version}")
+        raise fastapi.HTTPException(
+            status_code=fastapi.status.HTTP_426_UPGRADE_REQUIRED,
+            detail=f"Incompatible protocol version: {protocol_version}. Expected: {inference.INFERENCE_PROTOCOL_VERSION}.",
+        )
+    return protocol_version
 
 
 def get_worker(
     api_key: str = Depends(get_api_key),
+    protocol_version: str = Depends(get_protocol_version),
     session: sqlmodel.Session = Depends(create_session),
 ) -> models.DbWorkerEntry:
+    logger.info(f"get_worker: {api_key=}, {protocol_version=}")
     worker = session.exec(
         sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key)
     ).one_or_none()
@@ -171,7 +191,7 @@ async def create_chat(
     request: interface.CreateChatRequest, chat_repository: ChatRepository = Depends(create_chat_repository)
 ) -> interface.ChatListEntry:
     """Allows a client to create a new chat."""
-    logger.info(f"Received {request}")
+    logger.info(f"Received {request=}")
     chat = chat_repository.create_chat()
     return chat.to_list_entry()
 
@@ -243,8 +263,7 @@ async def event_generator(chat_id):
 async def work(websocket: fastapi.WebSocket, worker: models.DbWorkerEntry = Depends(get_worker)):
     await websocket.accept()
     worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
-    queue_id = f"work:{worker_config.compat_hash}"
-    work_queue = queueing.RedisQueue(redis_client, queue_id)
+    work_queue = queueing.work_queue(redis_client, worker_config.compat_hash)
     try:
         while True:
             if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
diff --git a/inference/server/oasst_inference_server/interface.py b/inference/server/oasst_inference_server/interface.py
index 55c24a3dbb..9349a9be48 100644
--- a/inference/server/oasst_inference_server/interface.py
+++ b/inference/server/oasst_inference_server/interface.py
@@ -11,7 +11,7 @@ class MessageRequest(pydantic.BaseModel):
 
     @property
     def worker_compat_hash(self) -> str:
-        return f"{self.model_name}"
+        return inference.compat_hash(model_name=self.model_name)
 
 
 class TokenResponseEvent(pydantic.BaseModel):
diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py
index 811c08050f..c46606a3ed 100644
--- a/inference/worker/__main__.py
+++ b/inference/worker/__main__.py
@@ -12,6 +12,8 @@
 
 
 def main():
+    logger.info(f"Inference protocol version: {inference.INFERENCE_PROTOCOL_VERSION}")
+
     utils.wait_for_inference_server(settings.inference_server_url)
 
     def on_open(ws: websocket.WebSocket):
@@ -95,6 +97,10 @@ def _prepare_message(message: protocol.ConversationMessage) -> str:
 
     def on_error(ws: websocket.WebSocket, error: Exception):
         try:
             raise error
+        except websocket.WebSocketBadStatusException as e:
+            logger.error(f"Bad status: {e.status_code=} {str(e)=}")
+            logger.error("Did you provide the correct API key?")
+            logger.error("Try upgrading the worker to get the latest protocol version")
         except Exception:
             logger.exception("Error in websocket")
@@ -107,7 +113,10 @@ def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
         on_error=on_error,
         on_close=on_close,
         on_open=on_open,
-        header={"X-API-Key": settings.api_key},
+        header={
+            "X-API-Key": settings.api_key,
+            "X-Protocol-Version": inference.INFERENCE_PROTOCOL_VERSION,
+        },
     )
 
     ws.run_forever(dispatcher=rel, reconnect=5)
diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py
index f8a94fc1bd..d255a37350 100644
--- a/oasst-shared/oasst_shared/schemas/inference.py
+++ b/oasst-shared/oasst_shared/schemas/inference.py
@@ -5,13 +5,19 @@
 
 from . import protocol
 
+INFERENCE_PROTOCOL_VERSION = "1"
+
+
+def compat_hash(*, model_name: str) -> str:
+    return f"{model_name}"
+
 
 class WorkerConfig(pydantic.BaseModel):
     model_name: str = "distilgpt2"
 
     @property
     def compat_hash(self) -> str:
-        return f"{self.model_name}"
+        return compat_hash(model_name=self.model_name)
 
 
 class WorkRequest(pydantic.BaseModel):