Skip to content

Commit

Permalink
Unify title update and visibility update inference endpoints (#2627)
Browse files Browse the repository at this point in the history
Close #2581
  • Loading branch information
olliestanley committed Apr 18, 2023
1 parent 533c53f commit 5f951b0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 30 deletions.
26 changes: 10 additions & 16 deletions inference/server/oasst_inference_server/routes/chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,25 +259,19 @@ async def handle_create_report(
return fastapi.Response(status_code=500)


@router.put("/{chat_id}/title")
async def handle_update_title(
@router.put("/{chat_id}")
async def handle_update_chat(
chat_id: str,
request: chat_schema.ChatUpdateTitleRequest,
request: chat_schema.ChatUpdateRequest,
ucr: deps.UserChatRepository = fastapi.Depends(deps.create_user_chat_repository),
) -> fastapi.Response:
"""Allows the client to update a chat title."""
"""Allows the client to update a chat."""
try:
await ucr.update_title(chat_id=chat_id, title=request.title)
await ucr.update_chat(
chat_id=chat_id,
title=request.title,
hidden=request.hidden,
)
except Exception:
logger.exception("Error when updating chat title")
logger.exception("Error when updating chat")
return fastapi.Response(status_code=500)


@router.put("/{chat_id}/hide")
async def update_visibility(
chat_id: str,
request: chat_schema.ChatUpdateVisibilityRequest,
ucr: UserChatRepository = Depends(deps.create_user_chat_repository),
):
await ucr.update_visibility(chat_id, request.hidden)
return fastapi.Response(status_code=200)
9 changes: 3 additions & 6 deletions inference/server/oasst_inference_server/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ def __init__(self, message: inference.MessageRead):
self.message = message


class ChatUpdateTitleRequest(pydantic.BaseModel):
title: pydantic.constr(max_length=100)


class ChatUpdateVisibilityRequest(pydantic.BaseModel):
hidden: bool
class ChatUpdateRequest(pydantic.BaseModel):
title: pydantic.constr(max_length=100) | None = None
hidden: bool | None = None
21 changes: 13 additions & 8 deletions inference/server/oasst_inference_server/user_chat_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,21 @@ async def add_report(self, message_id: str, reason: str, report_type: inference.
await self.session.refresh(report)
return report

async def update_title(self, chat_id: str, title: str) -> None:
logger.info(f"Updating title of chat {chat_id=}: {title=}")
async def update_chat(
self,
chat_id: str,
title: str | None = None,
hidden: bool | None = None,
) -> None:
logger.info(f"Updating chat {chat_id=}: {title=} {hidden=}")
chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False)

chat.title = title
await self.session.commit()
if title is not None:
logger.info(f"Updating title of chat {chat_id=}: {title=}")
chat.title = title

async def update_visibility(self, chat_id: str, hidden: bool) -> None:
logger.info(f"Setting chat {chat_id=} to {'hidden' if hidden else 'visible'}")
chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False)
if hidden is not None:
logger.info(f"Setting chat {chat_id=} to {'hidden' if hidden else 'visible'}")
chat.hidden = hidden

chat.hidden = hidden
await self.session.commit()

0 comments on commit 5f951b0

Please sign in to comment.