Skip to content

Commit

Permalink
Add ability to hide (soft delete) chats to inference backend (#2512)
Browse files Browse the repository at this point in the history
  • Loading branch information
olliestanley committed Apr 14, 2023
1 parent cf51432 commit 34ad7ec
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Add hidden field to chats
Revision ID: b66fd8f9da1f
Revises: f0e18084aae4
Create Date: 2023-04-14 16:11:35.361507
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "b66fd8f9da1f"
down_revision = "f0e18084aae4"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chat", sa.Column("hidden", sa.Boolean(), server_default=sa.text("false"), nullable=False))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat", "hidden")
# ### end Alembic commands ###
4 changes: 4 additions & 0 deletions inference/server/oasst_inference_server/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,15 @@ class DbChat(SQLModel, table=True):

messages: list[DbMessage] = Relationship(back_populates="chat")

hidden: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))

def to_list_read(self) -> chat_schema.ChatListRead:
return chat_schema.ChatListRead(
id=self.id,
created_at=self.created_at,
modified_at=self.modified_at,
title=self.title,
hidden=self.hidden,
)

def to_read(self) -> chat_schema.ChatRead:
Expand All @@ -89,6 +92,7 @@ def to_read(self) -> chat_schema.ChatRead:
modified_at=self.modified_at,
title=self.title,
messages=[m.to_read() for m in self.messages],
hidden=self.hidden,
)

def get_msg_dict(self) -> dict[str, DbMessage]:
Expand Down
13 changes: 12 additions & 1 deletion inference/server/oasst_inference_server/routes/chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

@router.get("")
async def list_chats(
include_hidden: bool = False,
ucr: UserChatRepository = Depends(deps.create_user_chat_repository),
) -> chat_schema.ListChatsResponse:
"""Lists all chats."""
logger.info("Listing all chats.")
chats = await ucr.get_chats()
chats = await ucr.get_chats(include_hidden=include_hidden)
chats_list = [chat.to_list_read() for chat in chats]
return chat_schema.ListChatsResponse(chats=chats_list)

Expand Down Expand Up @@ -270,3 +271,13 @@ async def handle_update_title(
except Exception:
logger.exception("Error when updating chat title")
return fastapi.Response(status_code=500)


@router.put("/{chat_id}/hide")
async def update_visibility(
chat_id: str,
hidden: bool,
ucr: UserChatRepository = Depends(deps.create_user_chat_repository),
):
await ucr.update_visibility(chat_id, hidden)
return fastapi.Response(status_code=200)
1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ChatListRead(pydantic.BaseModel):
created_at: datetime.datetime
modified_at: datetime.datetime
title: str | None
hidden: bool = False


class ChatRead(ChatListRead):
Expand Down
13 changes: 11 additions & 2 deletions inference/server/oasst_inference_server/user_chat_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ class UserChatRepository(pydantic.BaseModel):
class Config:
arbitrary_types_allowed = True

async def get_chats(self) -> list[models.DbChat]:
async def get_chats(self, include_hidden: bool = False) -> list[models.DbChat]:
query = sqlmodel.select(models.DbChat)
query = query.where(models.DbChat.user_id == self.user_id)
if not include_hidden:
query = query.where(models.DbChat.hidden.is_(False))
query = query.order_by(models.DbChat.created_at.desc())
return (await self.session.exec(query)).all()

Expand Down Expand Up @@ -226,9 +228,16 @@ async def add_report(self, message_id: str, reason: str, report_type: inference.
await self.session.refresh(report)
return report

async def update_title(self, chat_id: str, title: str) -> models.DbChat:
async def update_title(self, chat_id: str, title: str) -> None:
logger.info(f"Updating title of chat {chat_id=}: {title=}")
chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False)

chat.title = title
await self.session.commit()

async def update_visibility(self, chat_id: str, hidden: bool) -> None:
logger.info(f"Setting chat {chat_id=} to {'hidden' if hidden else 'visible'}")
chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False)

chat.hidden = hidden
await self.session.commit()

0 comments on commit 34ad7ec

Please sign in to comment.