Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added protocol version checks #1521

Merged
merged 1 commit into from
Feb 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 25 additions & 6 deletions inference/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ async def enable_prom_metrics():
Instrumentator().instrument(app).expose(app)


@app.on_event("startup")
async def log_inference_protocol_version():
    """Log ``inference.INFERENCE_PROTOCOL_VERSION`` once when the app starts."""
    logger.info(f"Inference protocol version: {inference.INFERENCE_PROTOCOL_VERSION}")


# Allow CORS
app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -70,19 +75,34 @@ def manual_chat_repository():
api_key_header = fastapi.Header(None, alias="X-API-Key")


def get_api_key(api_key_header: str = api_key_header) -> str:
if api_key_header is None:
def get_api_key(api_key: str = api_key_header) -> str:
if api_key is None:
raise fastapi.HTTPException(
status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
detail="Missing API key",
)
return api_key_header
return api_key


protocol_version_header = fastapi.Header(None, alias="X-Protocol-Version")


def get_protocol_version(protocol_version: str = protocol_version_header) -> str:
    """Validate the protocol version supplied via ``protocol_version_header``.

    Args:
        protocol_version: Value of the ``X-Protocol-Version`` header; the
            header default is ``None`` when the header is not sent.

    Returns:
        The protocol version, unchanged, when it equals
        ``inference.INFERENCE_PROTOCOL_VERSION``.

    Raises:
        fastapi.HTTPException: With status 426 (Upgrade Required) when the
            value differs from the expected version. A missing header
            (``None``) never equals the expected string, so it is rejected
            through the same path.
    """
    if protocol_version != inference.INFERENCE_PROTOCOL_VERSION:
        logger.warning(f"Got worker with incompatible protocol version: {protocol_version}")
        raise fastapi.HTTPException(
            status_code=fastapi.status.HTTP_426_UPGRADE_REQUIRED,
            detail=f"Incompatible protocol version: {protocol_version}. Expected: {inference.INFERENCE_PROTOCOL_VERSION}.",
        )
    return protocol_version


def get_worker(
api_key: str = Depends(get_api_key),
protocol_version: str = Depends(get_protocol_version),
session: sqlmodel.Session = Depends(create_session),
) -> models.DbWorkerEntry:
logger.info(f"get_worker: {api_key=}, {protocol_version=}")
worker = session.exec(
sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key)
).one_or_none()
Expand Down Expand Up @@ -171,7 +191,7 @@ async def create_chat(
request: interface.CreateChatRequest, chat_repository: ChatRepository = Depends(create_chat_repository)
) -> interface.ChatListEntry:
"""Allows a client to create a new chat."""
logger.info(f"Received {request}")
logger.info(f"Received {request=}")
chat = chat_repository.create_chat()
return chat.to_list_entry()

Expand Down Expand Up @@ -243,8 +263,7 @@ async def event_generator(chat_id):
async def work(websocket: fastapi.WebSocket, worker: models.DbWorkerEntry = Depends(get_worker)):
await websocket.accept()
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
queue_id = f"work:{worker_config.compat_hash}"
work_queue = queueing.RedisQueue(redis_client, queue_id)
work_queue = queueing.work_queue(redis_client, worker_config.compat_hash)
try:
while True:
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
Expand Down
2 changes: 1 addition & 1 deletion inference/server/oasst_inference_server/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MessageRequest(pydantic.BaseModel):

@property
def worker_compat_hash(self) -> str:
return f"{self.model_name}"
return inference.compat_hash(model_name=self.model_name)


class TokenResponseEvent(pydantic.BaseModel):
Expand Down
11 changes: 10 additions & 1 deletion inference/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@


def main():
logger.info(f"Inference protocol version: {inference.INFERENCE_PROTOCOL_VERSION}")

utils.wait_for_inference_server(settings.inference_server_url)

def on_open(ws: websocket.WebSocket):
Expand Down Expand Up @@ -95,6 +97,10 @@ def _prepare_message(message: protocol.ConversationMessage) -> str:
def on_error(ws: websocket.WebSocket, error: Exception):
try:
raise error
except websocket.WebSocketBadStatusException as e:
logger.error(f"Bad status: {e.status_code=} {str(e)=}")
logger.error("Did you provide the correct API key?")
logger.error("Try upgrading the worker to get the latest protocol version")
except Exception:
logger.exception("Error in websocket")

Expand All @@ -107,7 +113,10 @@ def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
on_error=on_error,
on_close=on_close,
on_open=on_open,
header={"X-API-Key": settings.api_key},
header={
"X-API-Key": settings.api_key,
"X-Protocol-Version": inference.INFERENCE_PROTOCOL_VERSION,
},
)

ws.run_forever(dispatcher=rel, reconnect=5)
Expand Down
8 changes: 7 additions & 1 deletion oasst-shared/oasst_shared/schemas/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@

from . import protocol

INFERENCE_PROTOCOL_VERSION = "1"


def compat_hash(*, model_name: str) -> str:
    """Return the compat-hash string for the given model name.

    At present the result is simply ``model_name`` formatted as a string.

    Args:
        model_name: Name of the model (keyword-only).

    Returns:
        The string form of ``model_name``.
    """
    return f"{model_name}"


class WorkerConfig(pydantic.BaseModel):
model_name: str = "distilgpt2"

@property
def compat_hash(self) -> str:
return f"{self.model_name}"
return compat_hash(model_name=self.model_name)


class WorkRequest(pydantic.BaseModel):
Expand Down