Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data management api #171

Merged
merged 6 commits into from
Dec 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
"""add deleted field to post

Revision ID: 8d269bc4fdbd
Revises: abb47e9d145a
Create Date: 2022-12-31 04:38:41.799206

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "8d269bc4fdbd"
down_revision = "abb47e9d145a"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "deleted")
# ### end Alembic commands ###
23 changes: 22 additions & 1 deletion backend/oasst_backend/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Generator
from uuid import UUID

from fastapi import Security
from fastapi import Depends, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from loguru import logger
from oasst_backend.config import settings
Expand Down Expand Up @@ -64,3 +64,24 @@ def api_auth(
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)


def get_api_client(
api_key: APIKey = Depends(get_api_key),
db: Session = Depends(get_db),
):
return api_auth(api_key, db)


def get_trusted_api_client(
api_key: APIKey = Depends(get_api_key),
db: Session = Depends(get_db),
):
client = api_auth(api_key, db)
if not client.trusted:
raise OasstError(
"Forbidden",
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)
return client
7 changes: 6 additions & 1 deletion backend/oasst_backend/api/v1/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst_backend.api.v1 import tasks, text_labels
from oasst_backend.api.v1 import frontend_messages, frontend_users, messages, stats, tasks, text_labels, users

api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
api_router.include_router(messages.router, prefix="/messages", tags=["messages"])
api_router.include_router(frontend_messages.router, prefix="/frontend_messages", tags=["frontend_messages"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
118 changes: 118 additions & 0 deletions backend/oasst_backend/api/v1/frontend_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-

from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session

router = APIRouter()


@router.get("/{message_id}")
def get_message_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)

if not isinstance(message.payload.payload, MessagePayload):
# Unexpected message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)

return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))


@router.get("/{message_id}/conversation")
def get_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a conversation from the tree root and up to the message with given frontend ID.
"""

pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)


@router.get("/{message_id}/tree")
def get_tree_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)


@router.get("/{message_id}/children")
def get_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]


@router.get("/{message_id}/descendants")
def get_descendants_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)


@router.get("/{message_id}/longest_conversation_in_tree")
def get_longest_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)


@router.get("/{message_id}/max_children_in_tree")
def get_max_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
60 changes: 60 additions & 0 deletions backend/oasst_backend/api/v1/frontend_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
import datetime
from uuid import UUID

from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.responses import Response
from starlette.status import HTTP_200_OK

router = APIRouter()


@router.get("/{username}/messages")
def query_frontend_user_messages(
username: str,
api_client_id: UUID = None,
max_count: int = Query(10, gt=0, le=1000),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
include_deleted: bool = False,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query frontend user messages.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
desc=desc,
limit=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
deleted=None if include_deleted else False,
)

return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]


@router.delete("/{username}/messages")
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)