Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BREAKING CHANGE: rename database objects #182

Merged
merged 14 commits into from
Dec 30, 2022
Merged

Large diffs are not rendered by default.

111 changes: 56 additions & 55 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def alembic_upgrade():

@app.on_event("startup")
def seed_data():
class DummyPost(pydantic.BaseModel):
task_post_id: str
user_post_id: str
parent_post_id: Optional[str]
class DummyMessage(pydantic.BaseModel):
task_message_id: str
user_message_id: str
parent_message_id: Optional[str]
text: str
role: str

Expand All @@ -81,96 +81,97 @@ class DummyPost(pydantic.BaseModel):
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)

dummy_posts = [
DummyPost(
task_post_id="de111fa8",
user_post_id="6f1d0711",
parent_post_id=None,
dummy_messages = [
DummyMessage(
task_message_id="de111fa8",
user_message_id="6f1d0711",
parent_message_id=None,
text="Hi!",
role="user",
andreaskoepf marked this conversation as resolved.
Show resolved Hide resolved
),
DummyPost(
task_post_id="74c381d4",
user_post_id="4a24530b",
parent_post_id="6f1d0711",
DummyMessage(
task_message_id="74c381d4",
user_message_id="4a24530b",
parent_message_id="6f1d0711",
text="Hello! How can I help you?",
role="assistant",
),
DummyPost(
task_post_id="3d5dc440",
user_post_id="a8c01c04",
parent_post_id="4a24530b",
DummyMessage(
task_message_id="3d5dc440",
user_message_id="a8c01c04",
parent_message_id="4a24530b",
text="Do you have a recipe for potato soup?",
role="user",
),
DummyPost(
task_post_id="643716c1",
user_post_id="f43a93b7",
parent_post_id="4a24530b",
DummyMessage(
task_message_id="643716c1",
user_message_id="f43a93b7",
parent_message_id="4a24530b",
text="Who were the 8 presidents before George Washington?",
role="user",
),
DummyPost(
task_post_id="2e4e1e6",
user_post_id="c886920",
parent_post_id="6f1d0711",
DummyMessage(
task_message_id="2e4e1e6",
user_message_id="c886920",
parent_message_id="6f1d0711",
text="Hey buddy! How can I serve you?",
role="assistant",
),
DummyPost(
task_post_id="970c437d",
user_post_id="cec432cf",
parent_post_id=None,
DummyMessage(
task_message_id="970c437d",
user_message_id="cec432cf",
parent_message_id=None,
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
role="user",
),
DummyPost(
task_post_id="6066118e",
user_post_id="4f85f637",
parent_post_id="cec432cf",
DummyMessage(
task_message_id="6066118e",
user_message_id="4f85f637",
parent_message_id="cec432cf",
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
role="assistant",
),
DummyPost(
task_post_id="ba87780d",
user_post_id="0e276b98",
parent_post_id="cec432cf",
DummyMessage(
task_message_id="ba87780d",
user_message_id="0e276b98",
parent_message_id="cec432cf",
text="I'm unsure how to interpret this. Is it a riddle?",
role="assistant",
),
]

for p in dummy_posts:
wp = pr.fetch_workpackage_by_postid(p.task_post_id)
if wp and not wp.ack:
for msg in dummy_messages:
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data work package")
andreaskoepf marked this conversation as resolved.
Show resolved Hide resolved
db.delete(wp)
wp = None
if not wp:
if p.parent_post_id is None:
wp = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_post_id=None
db.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
task = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
print("p.parent_post_id", p.parent_post_id)
parent_post = pr.fetch_post_by_frontend_post_id(p.parent_post_id, fail_if_missing=True)
wp = pr.store_task(
parent_message = pr.fetch_message_by_frontend_message_id(
msg.parent_message_id, fail_if_missing=True
)
task = pr.store_task(
protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
)
),
thread_id=parent_post.thread_id,
parent_post_id=parent_post.id,
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
pr.bind_frontend_post_id(wp.id, p.task_post_id)
post = pr.store_text_reply(p.text, p.task_post_id, p.user_post_id)
pr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)

logger.info(
f"Inserted: post_id: {post.id}, payload: {post.payload.payload}, parent_post_id: {post.parent_id}"
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
)
else:
logger.debug(f"seed data work_package found: {wp.id}")
logger.debug(f"seed data task found: {task.id}")
logger.info("Seed data check completed")

except Exception:
Expand Down
86 changes: 46 additions & 40 deletions backend/oasst_backend/api/v1/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
def generate_task(
request: protocol_schema.TaskRequest, pr: PromptRepository
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
thread_id = None
parent_post_id = None
message_tree_id = None
parent_message_id = None

match request.type:
case protocol_schema.TaskRequestType.random:
Expand Down Expand Up @@ -56,36 +56,40 @@ def generate_task(
)
case protocol_schema.TaskRequestType.user_reply:
logger.info("Generating a UserReplyTask.")
posts = pr.fetch_random_conversation("assistant")
messages = [
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
for p in posts
messages = pr.fetch_random_conversation("assistant")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
)
for msg in messages
]

task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = posts[-1].thread_id
parent_post_id = posts[-1].id
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating a AssistantReplyTask.")
posts = pr.fetch_random_conversation("user")
messages = [
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
for p in posts
messages = pr.fetch_random_conversation("user")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
)
for msg in messages
]

task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = posts[-1].thread_id
parent_post_id = posts[-1].id
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.rank_initial_prompts:
logger.info("Generating a RankInitialPromptsTask.")

posts = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[p.payload.payload.text for p in posts])
messages = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
case protocol_schema.TaskRequestType.rank_user_replies:
logger.info("Generating a RankUserRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(post_role="assistant")
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")

messages = [
task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
is_assistant=(p.role == "assistant"),
Expand All @@ -95,16 +99,16 @@ def generate_task(
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankUserRepliesTask(
conversation=protocol_schema.Conversation(
messages=messages,
messages=task_messages,
),
replies=replies,
)

case protocol_schema.TaskRequestType.rank_assistant_replies:
logger.info("Generating a RankAssistantRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(post_role="user")
conversation, replies = pr.fetch_multiple_random_replies(message_role="user")

messages = [
task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
is_assistant=(p.role == "assistant"),
Expand All @@ -113,15 +117,15 @@ def generate_task(
]
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(messages=messages),
conversation=protocol_schema.Conversation(messages=task_messages),
replies=replies,
)
case _:
raise OasstError("Invalid request type", OasstErrorCode.TASK_INVALID_REQUEST_TYPE)

logger.info(f"Generated {task=}.")

return task, thread_id, parent_post_id
return task, message_tree_id, parent_message_id


@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
Expand All @@ -138,8 +142,8 @@ def request_task(

try:
pr = PromptRepository(db, api_client, request.user)
task, thread_id, parent_post_id = generate_task(request, pr)
pr.store_task(task, thread_id, parent_post_id, request.collective)
task, message_tree_id, parent_message_id = generate_task(request, pr)
pr.store_task(task, message_tree_id, parent_message_id, request.collective)

except OasstError:
raise
Expand All @@ -150,7 +154,7 @@ def request_task(


@router.post("/{task_id}/ack")
def acknowledge_task(
def tasks_acknowledge(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
Expand All @@ -166,9 +170,9 @@ def acknowledge_task(
try:
pr = PromptRepository(db, api_client, user=None)

# here we store the post id in the database for the task
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)

except OasstError:
raise
Expand All @@ -179,7 +183,7 @@ def acknowledge_task(


@router.post("/{task_id}/nack")
def acknowledge_task_failure(
def tasks_acknowledge_failure(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
Expand All @@ -201,7 +205,7 @@ def acknowledge_task_failure(


@router.post("/interaction")
def post_interaction(
def tasks_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
Expand All @@ -216,29 +220,31 @@ def post_interaction(
pr = PromptRepository(db, api_client, user=interaction.user)

match type(interaction):
case protocol_schema.TextReplyToPost:
case protocol_schema.TextReplyToMessage:
logger.info(
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)

# here we store the text reply in the database
pr.store_text_reply(
text=interaction.text, post_id=interaction.post_id, user_post_id=interaction.user_post_id
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)

return protocol_schema.TaskDone()
case protocol_schema.PostRating:
case protocol_schema.MessageRating:
logger.info(
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
)

# here we store the rating in the database
pr.store_rating(interaction)

return protocol_schema.TaskDone()
case protocol_schema.PostRanking:
case protocol_schema.MessageRanking:
logger.info(
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
)

# TODO: check if the ranking is valid
Expand All @@ -262,5 +268,5 @@ def close_collective_task(
):
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.close_task(close_task_request.post_id)
pr.close_task(close_task_request.message_id)
return protocol_schema.TaskDone()
18 changes: 9 additions & 9 deletions backend/oasst_backend/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@ class OasstErrorCode(IntEnum):
TASK_GENERATION_FAILED = 1005

# 2000-3000: prompt_repository
INVALID_POST_ID = 2000
POST_NOT_FOUND = 2001
INVALID_FRONTEND_MESSAGE_ID = 2000
MESSAGE_NOT_FOUND = 2001
RATING_OUT_OF_RANGE = 2002
INVALID_RANKING_VALUE = 2003
INVALID_TASK_TYPE = 2004
USER_NOT_SPECIFIED = 2005
NO_THREADS_FOUND = 2006
andreaskoepf marked this conversation as resolved.
Show resolved Hide resolved
NO_REPLIES_FOUND = 2007
WORK_PACKAGE_NOT_FOUND = 2100
WORK_PACKAGE_EXPIRED = 2101
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
WORK_PACKAGE_ALREADY_UPDATED = 2103
WORK_PACKAGE_NOT_ACK = 2104
WORK_PACKAGE_ALREADY_DONE = 2105
WORK_PACKAGE_NOT_COLLECTIVE = 2106
TASK_NOT_FOUND = 2100
TASK_EXPIRED = 2101
TASK_PAYLOAD_TYPE_MISMATCH = 2102
TASK_ALREADY_UPDATED = 2103
TASK_NOT_ACK = 2104
TASK_ALREADY_DONE = 2105
TASK_NOT_COLLECTIVE = 2106


class OasstError(Exception):
Expand Down