diff --git a/inference/server/alembic/versions/2023_04_24_2130-401eef162771_add_chat_data_opt_out_field.py b/inference/server/alembic/versions/2023_04_24_2130-401eef162771_add_chat_data_opt_out_field.py new file mode 100644 index 0000000000..df7f4a41f1 --- /dev/null +++ b/inference/server/alembic/versions/2023_04_24_2130-401eef162771_add_chat_data_opt_out_field.py @@ -0,0 +1,27 @@ +"""Add chat data opt out field + +Revision ID: 401eef162771 +Revises: b66fd8f9da1f +Create Date: 2023-04-24 21:30:19.947411 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "401eef162771" +down_revision = "b66fd8f9da1f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("chat", sa.Column("allow_data_use", sa.Boolean(), server_default=sa.text("true"), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("chat", "allow_data_use") + # ### end Alembic commands ### diff --git a/inference/server/oasst_inference_server/models/chat.py b/inference/server/oasst_inference_server/models/chat.py index 490f16965a..9c565cb8e0 100644 --- a/inference/server/oasst_inference_server/models/chat.py +++ b/inference/server/oasst_inference_server/models/chat.py @@ -76,6 +76,8 @@ class DbChat(SQLModel, table=True): hidden: bool = Field(False, sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false())) + allow_data_use: bool = Field(True, sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true())) + def to_list_read(self) -> chat_schema.ChatListRead: return chat_schema.ChatListRead( id=self.id, diff --git a/inference/server/oasst_inference_server/routes/chats.py b/inference/server/oasst_inference_server/routes/chats.py index 68448bd02a..d336223bec 100644 --- a/inference/server/oasst_inference_server/routes/chats.py +++ b/inference/server/oasst_inference_server/routes/chats.py @@ -311,6 +311,7 @@ async def handle_update_chat( chat_id=chat_id, title=request.title, hidden=request.hidden, + allow_data_use=request.allow_data_use, ) except Exception: logger.exception("Error when updating chat") diff --git a/inference/server/oasst_inference_server/schemas/chat.py b/inference/server/oasst_inference_server/schemas/chat.py index 64ba2b94b5..653e677e89 100644 --- a/inference/server/oasst_inference_server/schemas/chat.py +++ b/inference/server/oasst_inference_server/schemas/chat.py @@ -89,3 +89,4 @@ def __init__(self, message: inference.MessageRead): class ChatUpdateRequest(pydantic.BaseModel): title: pydantic.constr(max_length=100) | None = None hidden: bool | None = None + allow_data_use: bool | None = None diff --git a/inference/server/oasst_inference_server/user_chat_repository.py b/inference/server/oasst_inference_server/user_chat_repository.py index fe5475457a..a760602fc3 100644 --- a/inference/server/oasst_inference_server/user_chat_repository.py +++ b/inference/server/oasst_inference_server/user_chat_repository.py @@ -275,6 +275,7 @@ async def update_chat( chat_id: str, title: str | None = None, hidden: bool | None = None, + allow_data_use: bool | None = None, ) -> None: logger.info(f"Updating chat {chat_id=}: {title=} {hidden=}") chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False) @@ -287,4 +288,8 @@ async def update_chat( logger.info(f"Setting chat {chat_id=} to {'hidden' if hidden else 'visible'}") chat.hidden = hidden + if allow_data_use is not None: + logger.info(f"Updating allow_data_use of chat {chat_id=}: {allow_data_use=}") + chat.allow_data_use = allow_data_use + await self.session.commit()