Skip to content

Commit

Permalink
added protocl version checks (#1521)
Browse files Browse the repository at this point in the history
  • Loading branch information
yk committed Feb 12, 2023
1 parent dd2f17d commit 4f92001
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 9 deletions.
31 changes: 25 additions & 6 deletions inference/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ async def enable_prom_metrics():
Instrumentator().instrument(app).expose(app)


@app.on_event("startup")
async def log_inference_protocol_version():
logger.info(f"Inference protocol version: {inference.INFERENCE_PROTOCOL_VERSION}")


# Allow CORS
app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -70,19 +75,34 @@ def manual_chat_repository():
api_key_header = fastapi.Header(None, alias="X-API-Key")


def get_api_key(api_key_header: str = api_key_header) -> str:
if api_key_header is None:
def get_api_key(api_key: str = api_key_header) -> str:
if api_key is None:
raise fastapi.HTTPException(
status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
detail="Missing API key",
)
return api_key_header
return api_key


protocol_version_header = fastapi.Header(None, alias="X-Protocol-Version")


def get_protocol_version(protocol_version: str = protocol_version_header) -> str:
if protocol_version != inference.INFERENCE_PROTOCOL_VERSION:
logger.warning(f"Got worker with incompatible protocol version: {protocol_version}")
raise fastapi.HTTPException(
status_code=fastapi.status.HTTP_426_UPGRADE_REQUIRED,
detail=f"Incompatible protocol version: {protocol_version}. Expected: {inference.INFERENCE_PROTOCOL_VERSION}.",
)
return protocol_version


def get_worker(
api_key: str = Depends(get_api_key),
protocol_version: str = Depends(get_protocol_version),
session: sqlmodel.Session = Depends(create_session),
) -> models.DbWorkerEntry:
logger.info(f"get_worker: {api_key=}, {protocol_version=}")
worker = session.exec(
sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key)
).one_or_none()
Expand Down Expand Up @@ -171,7 +191,7 @@ async def create_chat(
request: interface.CreateChatRequest, chat_repository: ChatRepository = Depends(create_chat_repository)
) -> interface.ChatListEntry:
"""Allows a client to create a new chat."""
logger.info(f"Received {request}")
logger.info(f"Received {request=}")
chat = chat_repository.create_chat()
return chat.to_list_entry()

Expand Down Expand Up @@ -243,8 +263,7 @@ async def event_generator(chat_id):
async def work(websocket: fastapi.WebSocket, worker: models.DbWorkerEntry = Depends(get_worker)):
await websocket.accept()
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
queue_id = f"work:{worker_config.compat_hash}"
work_queue = queueing.RedisQueue(redis_client, queue_id)
work_queue = queueing.work_queue(redis_client, worker_config.compat_hash)
try:
while True:
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
Expand Down
2 changes: 1 addition & 1 deletion inference/server/oasst_inference_server/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MessageRequest(pydantic.BaseModel):

@property
def worker_compat_hash(self) -> str:
return f"{self.model_name}"
return inference.compat_hash(model_name=self.model_name)


class TokenResponseEvent(pydantic.BaseModel):
Expand Down
11 changes: 10 additions & 1 deletion inference/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@


def main():
logger.info(f"Inference protocol version: {inference.INFERENCE_PROTOCOL_VERSION}")

utils.wait_for_inference_server(settings.inference_server_url)

def on_open(ws: websocket.WebSocket):
Expand Down Expand Up @@ -95,6 +97,10 @@ def _prepare_message(message: protocol.ConversationMessage) -> str:
def on_error(ws: websocket.WebSocket, error: Exception):
try:
raise error
except websocket.WebSocketBadStatusException as e:
logger.error(f"Bad status: {e.status_code=} {str(e)=}")
logger.error("Did you provide the correct API key?")
logger.error("Try upgrading the worker to get the latest protocol version")
except Exception:
logger.exception("Error in websocket")

Expand All @@ -107,7 +113,10 @@ def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
on_error=on_error,
on_close=on_close,
on_open=on_open,
header={"X-API-Key": settings.api_key},
header={
"X-API-Key": settings.api_key,
"X-Protocol-Version": inference.INFERENCE_PROTOCOL_VERSION,
},
)

ws.run_forever(dispatcher=rel, reconnect=5)
Expand Down
8 changes: 7 additions & 1 deletion oasst-shared/oasst_shared/schemas/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@

from . import protocol

INFERENCE_PROTOCOL_VERSION = "1"


def compat_hash(*, model_name: str) -> str:
return f"{model_name}"


class WorkerConfig(pydantic.BaseModel):
model_name: str = "distilgpt2"

@property
def compat_hash(self) -> str:
return f"{self.model_name}"
return compat_hash(model_name=self.model_name)


class WorkRequest(pydantic.BaseModel):
Expand Down

0 comments on commit 4f92001

Please sign in to comment.