Skip to content

Commit

Permalink
adjust names and types to new naming
Browse files Browse the repository at this point in the history
  • Loading branch information
mjagkow committed Dec 31, 2022
1 parent b78bdfb commit 2f77590
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 81 deletions.
30 changes: 15 additions & 15 deletions backend/oasst_backend/api/v1/frontend_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from oasst_backend.api.v1 import utils
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
Expand All @@ -21,10 +21,10 @@ def get_message_by_frontend_id(
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
message = pr.fetch_message_by_frontend_message_id(message_id)

if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)

return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))

Expand All @@ -38,7 +38,7 @@ def get_conv_by_frontend_id(
"""

pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)

Expand All @@ -52,9 +52,9 @@ def get_tree_by_frontend_id(
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
tree = pr.fetch_message_tree(message)
return utils.prepare_tree(tree, message.thread_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)


@router.get("/{message_id}/children")
Expand All @@ -65,7 +65,7 @@ def get_children_by_frontend_id(
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return [
protocol.Message(
Expand All @@ -84,8 +84,8 @@ def get_descendants_by_frontend_id(
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
descendants = pr.fetch_post_descendants(message)
message = pr.fetch_message_by_frontend_message_id(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)


Expand All @@ -98,8 +98,8 @@ def get_longest_conv_by_frontend_id(
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)


Expand All @@ -112,6 +112,6 @@ def get_max_children_by_frontend_id(
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
26 changes: 13 additions & 13 deletions backend/oasst_backend/api/v1/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from oasst_backend.api.v1 import utils
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
Expand Down Expand Up @@ -61,9 +61,9 @@ def get_message(
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
message = pr.fetch_message(message_id)
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)

return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))

Expand All @@ -89,9 +89,9 @@ def get_tree(
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
tree = pr.fetch_message_tree(message)
return utils.prepare_tree(tree, message.thread_id)
message = pr.fetch_message(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)


@router.get("/{message_id}/children")
Expand Down Expand Up @@ -119,8 +119,8 @@ def get_descendants(
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
descendants = pr.fetch_post_descendants(message)
message = pr.fetch_message(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)


Expand All @@ -132,8 +132,8 @@ def get_longest_conv(
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
message = pr.fetch_message(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)


Expand All @@ -145,8 +145,8 @@ def get_max_children(
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
message = pr.fetch_message(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)


Expand Down
12 changes: 6 additions & 6 deletions backend/oasst_backend/api/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from uuid import UUID

from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import Post
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.models import Message
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.schemas import protocol


def prepare_conversation(messages: list[Post]) -> protocol.Conversation:
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
conv_messages = []
for message in messages:
if not isinstance(message.payload.payload, PostPayload):
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv_messages.append(
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
Expand All @@ -21,10 +21,10 @@ def prepare_conversation(messages: list[Post]) -> protocol.Conversation:
return protocol.Conversation(messages=conv_messages)


def prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
tree_messages = []
for message in tree:
if not isinstance(message.payload.payload, PostPayload):
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(
protocol.Message(
Expand Down
6 changes: 3 additions & 3 deletions backend/oasst_backend/journal_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, db: Session, api_client: ApiClient, user: User):
self.user = user
self.user_id = self.user.id if self.user else None

def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal:
def log_text_reply(self, task: Task, message_id: Optional[UUID], role: str, length: int) -> Journal:
return self.log(
task_type=task.payload_type,
event_type=JournalEventType.text_reply_to_message,
Expand All @@ -63,7 +63,7 @@ def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -
message_id=message_id,
)

def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal:
def log_rating(self, task: Task, message_id: Optional[UUID], rating: int) -> Journal:
return self.log(
task_type=task.payload_type,
event_type=JournalEventType.message_rating,
Expand All @@ -72,7 +72,7 @@ def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal:
message_id=message_id,
)

def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal:
def log_ranking(self, task: Task, message_id: Optional[UUID], ranking: list[int]) -> Journal:
return self.log(
task_type=task.payload_type,
event_type=JournalEventType.message_ranking,
Expand Down

0 comments on commit 2f77590

Please sign in to comment.