From 5f951b0a15db049cbfe68d906635df2e4b6be0b6 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Tue, 18 Apr 2023 05:21:54 +0100 Subject: [PATCH] Unify title update and visibility update inference endpoints (#2627) Close #2581 --- .../oasst_inference_server/routes/chats.py | 26 +++++++------------ .../oasst_inference_server/schemas/chat.py | 9 +++---- .../user_chat_repository.py | 21 +++++++++------ 3 files changed, 26 insertions(+), 30 deletions(-) diff --git a/inference/server/oasst_inference_server/routes/chats.py b/inference/server/oasst_inference_server/routes/chats.py index 499b9dbd94..be06227596 100644 --- a/inference/server/oasst_inference_server/routes/chats.py +++ b/inference/server/oasst_inference_server/routes/chats.py @@ -259,25 +259,19 @@ async def handle_create_report( return fastapi.Response(status_code=500) -@router.put("/{chat_id}/title") -async def handle_update_title( +@router.put("/{chat_id}") +async def handle_update_chat( chat_id: str, - request: chat_schema.ChatUpdateTitleRequest, + request: chat_schema.ChatUpdateRequest, ucr: deps.UserChatRepository = fastapi.Depends(deps.create_user_chat_repository), ) -> fastapi.Response: - """Allows the client to update a chat title.""" + """Allows the client to update a chat.""" try: - await ucr.update_title(chat_id=chat_id, title=request.title) + await ucr.update_chat( + chat_id=chat_id, + title=request.title, + hidden=request.hidden, + ) except Exception: - logger.exception("Error when updating chat title") + logger.exception("Error when updating chat") return fastapi.Response(status_code=500) - - -@router.put("/{chat_id}/hide") -async def update_visibility( - chat_id: str, - request: chat_schema.ChatUpdateVisibilityRequest, - ucr: UserChatRepository = Depends(deps.create_user_chat_repository), -): - await ucr.update_visibility(chat_id, request.hidden) - return fastapi.Response(status_code=200) diff --git a/inference/server/oasst_inference_server/schemas/chat.py b/inference/server/oasst_inference_server/schemas/chat.py index b906041d56..dc0147ab54 100644 --- a/inference/server/oasst_inference_server/schemas/chat.py +++ b/inference/server/oasst_inference_server/schemas/chat.py @@ -84,9 +84,6 @@ def __init__(self, message: inference.MessageRead): self.message = message -class ChatUpdateTitleRequest(pydantic.BaseModel): - title: pydantic.constr(max_length=100) - - -class ChatUpdateVisibilityRequest(pydantic.BaseModel): - hidden: bool +class ChatUpdateRequest(pydantic.BaseModel): + title: pydantic.constr(max_length=100) | None = None + hidden: bool | None = None diff --git a/inference/server/oasst_inference_server/user_chat_repository.py b/inference/server/oasst_inference_server/user_chat_repository.py index ddad7a6539..216c679822 100644 --- a/inference/server/oasst_inference_server/user_chat_repository.py +++ b/inference/server/oasst_inference_server/user_chat_repository.py @@ -228,16 +228,21 @@ async def add_report(self, message_id: str, reason: str, report_type: inference. await self.session.refresh(report) return report - async def update_title(self, chat_id: str, title: str) -> None: - logger.info(f"Updating title of chat {chat_id=}: {title=}") + async def update_chat( + self, + chat_id: str, + title: str | None = None, + hidden: bool | None = None, + ) -> None: + logger.info(f"Updating chat {chat_id=}: {title=} {hidden=}") chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False) - chat.title = title - await self.session.commit() + if title is not None: + logger.info(f"Updating title of chat {chat_id=}: {title=}") + chat.title = title - async def update_visibility(self, chat_id: str, hidden: bool) -> None: - logger.info(f"Setting chat {chat_id=} to {'hidden' if hidden else 'visible'}") - chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False) + if hidden is not None: + logger.info(f"Setting chat {chat_id=} to {'hidden' if hidden else 'visible'}") + chat.hidden = hidden - chat.hidden = hidden await self.session.commit()