Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add fastapi response model to every endpoints, add openapi documentation for API response #295

Merged
merged 1 commit into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 8 additions & 7 deletions backend/oasst_backend/api/v1/frontend_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session

router = APIRouter()


@router.get("/{message_id}")
@router.get("/{message_id}", response_model=protocol.Message)
def get_message_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -27,7 +28,7 @@ def get_message_by_frontend_id(
return utils.prepare_message(message)


@router.get("/{message_id}/conversation")
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
def get_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -41,7 +42,7 @@ def get_conv_by_frontend_id(
return utils.prepare_conversation(messages)


@router.get("/{message_id}/tree")
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
def get_tree_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -55,7 +56,7 @@ def get_tree_by_frontend_id(
return utils.prepare_tree(tree, message.message_tree_id)


@router.get("/{message_id}/children")
@router.get("/{message_id}/children", response_model=list[protocol.Message])
def get_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -68,7 +69,7 @@ def get_children_by_frontend_id(
return utils.prepare_message_list(messages)


@router.get("/{message_id}/descendants")
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
def get_descendants_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -82,7 +83,7 @@ def get_descendants_by_frontend_id(
return utils.prepare_tree(descendants, message.id)


@router.get("/{message_id}/longest_conversation_in_tree")
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
def get_longest_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -96,7 +97,7 @@ def get_longest_conv_by_frontend_id(
return utils.prepare_conversation(conv)


@router.get("/{message_id}/max_children_in_tree")
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
def get_max_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand Down
9 changes: 4 additions & 5 deletions backend/oasst_backend/api/v1/frontend_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.responses import Response
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_204_NO_CONTENT

router = APIRouter()


@router.get("/{username}/messages")
@router.get("/{username}/messages", response_model=list[protocol.Message])
def query_frontend_user_messages(
username: str,
api_client_id: UUID = None,
Expand Down Expand Up @@ -43,11 +43,10 @@ def query_frontend_user_messages(
return utils.prepare_message_list(messages)


@router.delete("/{username}/messages")
@router.delete("/{username}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)
24 changes: 12 additions & 12 deletions backend/oasst_backend/api/v1/messages.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import datetime
from uuid import UUID

from fastapi import APIRouter, Depends, Query, Response
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_204_NO_CONTENT

router = APIRouter()


@router.get("/")
@router.get("/", response_model=list[protocol.Message])
def query_messages(
username: str = None,
api_client_id: str = None,
Expand Down Expand Up @@ -45,7 +46,7 @@ def query_messages(
return utils.prepare_message_list(messages)


@router.get("/{message_id}")
@router.get("/{message_id}", response_model=protocol.Message)
def get_message(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -61,7 +62,7 @@ def get_message(
return utils.prepare_message(message)


@router.get("/{message_id}/conversation")
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
def get_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -74,7 +75,7 @@ def get_conv(
return utils.prepare_conversation(messages)


@router.get("/{message_id}/tree")
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
def get_tree(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -87,7 +88,7 @@ def get_tree(
return utils.prepare_tree(tree, message.message_tree_id)


@router.get("/{message_id}/children")
@router.get("/{message_id}/children", response_model=list[protocol.Message])
def get_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -99,7 +100,7 @@ def get_children(
return utils.prepare_message_list(messages)


@router.get("/{message_id}/descendants")
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
def get_descendants(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -112,7 +113,7 @@ def get_descendants(
return utils.prepare_tree(descendants, message.id)


@router.get("/{message_id}/longest_conversation_in_tree")
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
def get_longest_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -125,7 +126,7 @@ def get_longest_conv(
return utils.prepare_conversation(conv)


@router.get("/{message_id}/max_children_in_tree")
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
def get_max_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
Expand All @@ -138,10 +139,9 @@ def get_max_children(
return utils.prepare_tree([message, *children], message.id)


@router.delete("/{message_id}")
@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT)
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr.mark_messages_deleted(message_id)
return Response(status_code=HTTP_200_OK)
3 changes: 2 additions & 1 deletion backend/oasst_backend/api/v1/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session

router = APIRouter()


@router.get("/")
@router.get("/", response_model=protocol.SystemStats)
def get_message_stats(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
Expand Down
11 changes: 6 additions & 5 deletions backend/oasst_backend/api/v1/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT

router = APIRouter()

Expand Down Expand Up @@ -159,14 +160,14 @@ def request_task(
return task


@router.post("/{task_id}/ack", response_model=None)
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
ack_request: protocol_schema.TaskAck,
) -> Any:
) -> None:
"""
The frontend acknowledges a task.
"""
Expand All @@ -187,14 +188,14 @@ def tasks_acknowledge(
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)


@router.post("/{task_id}/nack", response_model=None)
@router.post("/{task_id}/nack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge_failure(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
nack_request: protocol_schema.TaskNAck,
) -> Any:
) -> None:
"""
The frontend reports failure to implement a task.
"""
Expand Down Expand Up @@ -265,7 +266,7 @@ def tasks_interaction(
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)


@router.post("/close")
@router.post("/close", response_model=protocol_schema.TaskDone)
def close_collective_task(
close_task_request: protocol_schema.TaskClose,
db: Session = Depends(deps.get_db),
Expand Down
4 changes: 2 additions & 2 deletions backend/oasst_backend/api/v1/text_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST

router = APIRouter()

Expand All @@ -16,7 +16,7 @@ class LabelTextRequest(pydantic.BaseModel):
user: protocol_schema.User


@router.post("/")
@router.post("/", status_code=HTTP_204_NO_CONTENT)
def label_text(
*,
db: Session = Depends(deps.get_db),
Expand Down
16 changes: 5 additions & 11 deletions backend/oasst_backend/api/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@

from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.responses import Response
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_204_NO_CONTENT

router = APIRouter()


@router.get("/{user_id}/messages")
@router.get("/{user_id}/messages", response_model=list[protocol.Message])
def query_user_messages(
user_id: UUID,
api_client_id: UUID = None,
Expand Down Expand Up @@ -41,19 +41,13 @@ def query_user_messages(
deleted=None if include_deleted else False,
)

return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
return utils.prepare_message_list(messages)


@router.delete("/{user_id}/messages")
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)