Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lang-tag based task selection (lang-separation) #863

Merged
merged 4 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""use 'en' instead 'en-US' as default lang

Revision ID: 160ac010efcc
Revises: 7f0a28a156f4
Create Date: 2023-01-20 14:54:09.168217

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "160ac010efcc"
down_revision = "7f0a28a156f4"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "lang")
op.add_column("message", sa.Column("lang", sa.String(length=32), server_default="en", nullable=False))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "lang")
op.add_column("message", sa.Column("lang", sa.VARCHAR(length=200), autoincrement=False, nullable=False))
# ### end Alembic commands ###
2 changes: 2 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class DummyMessage(BaseModel):
user_message_id: str
parent_message_id: Optional[str]
text: str
lang: Optional[str]
role: str
tree_state: Optional[message_tree_state.State]

Expand Down Expand Up @@ -184,6 +185,7 @@ class DummyMessage(BaseModel):
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(
msg.text,
msg.lang,
msg.task_message_id,
msg.user_message_id,
review_count=5,
Expand Down
5 changes: 3 additions & 2 deletions backend/oasst_backend/api/v1/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def request_task(
pr.ensure_user_is_enabled()

tm = TreeManager(db, pr)
task, message_tree_id, parent_message_id = tm.next_task(request.type)
task, message_tree_id, parent_message_id = tm.next_task(desired_task_type=request.type, lang=request.lang)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)

except OasstError:
Expand All @@ -54,6 +54,7 @@ def request_task(
def tasks_availability(
*,
user: Optional[protocol_schema.User] = None,
lang: Optional[str] = "en",
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
):
Expand All @@ -62,7 +63,7 @@ def tasks_availability(
try:
pr = PromptRepository(db, api_client, client_user=user)
tm = TreeManager(db, pr)
return tm.determine_task_availability()
return tm.determine_task_availability(lang)

except OasstError:
raise
Expand Down
6 changes: 4 additions & 2 deletions backend/oasst_backend/api/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def prepare_message(m: Message) -> protocol.Message:
frontend_message_id=m.frontend_message_id,
parent_id=m.parent_id,
text=m.text,
lang=m.lang,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
)
Expand All @@ -22,10 +23,11 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]:
return [
protocol.ConversationMessage(
text=message.text,
is_assistant=(message.role == "assistant"),
id=message.id,
frontend_message_id=message.frontend_message_id,
text=message.text,
lang=message.lang,
is_assistant=(message.role == "assistant"),
)
for message in messages
]
Expand Down
2 changes: 1 addition & 1 deletion backend/oasst_backend/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Message(SQLModel, table=True):
payload: Optional[PayloadContainer] = Field(
sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)
)
lang: str = Field(nullable=False, max_length=200, default="en-US")
lang: str = Field(sa_column=sa.Column(sa.String(32), server_default="en", nullable=False))
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
Expand Down
4 changes: 4 additions & 0 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def insert_message(
task_id: UUID,
role: str,
payload: db_payload.MessagePayload,
lang: str,
payload_type: str = None,
depth: int = 0,
review_count: int = 0,
Expand All @@ -107,6 +108,7 @@ def insert_message(
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
lang=lang,
depth=depth,
review_count=review_count,
review_result=review_result,
Expand Down Expand Up @@ -146,6 +148,7 @@ def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
def store_text_reply(
self,
text: str,
lang: str,
frontend_message_id: str,
user_frontend_message_id: str,
review_count: int = 0,
Expand Down Expand Up @@ -209,6 +212,7 @@ def store_text_reply(
task_id=task.id,
role=role,
payload=db_payload.MessagePayload(text=text),
lang=lang or "en",
depth=depth,
review_count=review_count,
review_result=review_result,
Expand Down
64 changes: 43 additions & 21 deletions backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,18 @@ def _determine_task_availability_internal(

return task_count_by_type

def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
self.pr.ensure_user_is_enabled()

num_active_trees = self.query_num_active_trees()
extendible_parents = self.query_extendible_parents()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
incomplete_rankings = self.query_incomplete_rankings()
if not lang:
lang = "en"
logger.warning("Task availability request without lang tag received, assuming lang='en'.")

num_active_trees = self.query_num_active_trees(lang=lang)
extendible_parents = self.query_extendible_parents(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
incomplete_rankings = self.query_incomplete_rankings(lang=lang)

return self._determine_task_availability_internal(
num_active_trees=num_active_trees,
Expand All @@ -208,23 +212,29 @@ def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, i
)

def next_task(
self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random
self,
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
lang: str = "en",
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:

logger.debug("TreeManager.next_task()")
logger.debug(f"TreeManager.next_task({desired_task_type=}, {lang=})")

self.pr.ensure_user_is_enabled()

num_active_trees = self.query_num_active_trees()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
extendible_parents = self.query_extendible_parents()
if not lang:
lang = "en"
logger.warning("Task request without lang tag received, assuming 'en'.")

incomplete_rankings = self.query_incomplete_rankings()
num_active_trees = self.query_num_active_trees(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
extendible_parents = self.query_extendible_parents(lang=lang)

incomplete_rankings = self.query_incomplete_rankings(lang=lang)
if not self.cfg.rank_prompter_replies:
incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings))

active_tree_sizes = self.query_extendible_trees()
active_tree_sizes = self.query_extendible_trees(lang=lang)

# determine type of task to generate
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
Expand Down Expand Up @@ -458,6 +468,7 @@ async def handle_interaction(self, interaction: protocol_schema.AnyInteraction)
# here we store the text reply in the database
message = pr.store_text_reply(
text=interaction.text,
lang=interaction.lang,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
Expand Down Expand Up @@ -665,7 +676,7 @@ def _calculate_acceptance(self, labels: list[TextLabels]):
# calculate acceptance based on spam label
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])

def query_prompts_need_review(self) -> list[Message]:
def query_prompts_need_review(self, lang: str) -> list[Message]:
"""
Select initial prompt messages with less then required rankings in active message tree
(active == True in message_tree_state)
Expand All @@ -682,6 +693,7 @@ def query_prompts_need_review(self) -> list[Message]:
not_(Message.deleted),
Message.review_count < self.cfg.num_reviews_initial_prompt,
Message.parent_id.is_(None),
Message.lang == lang,
)
)

Expand All @@ -690,7 +702,7 @@ def query_prompts_need_review(self) -> list[Message]:

return qry.all()

def query_replies_need_review(self) -> list[Message]:
def query_replies_need_review(self, lang: str) -> list[Message]:
"""
Select child messages (parent_id IS NOT NULL) with less then required rankings
in active message tree (active == True in message_tree_state)
Expand All @@ -707,6 +719,7 @@ def query_replies_need_review(self) -> list[Message]:
not_(Message.deleted),
Message.review_count < self.cfg.num_reviews_reply,
Message.parent_id.is_not(None),
Message.lang == lang,
)
)

Expand All @@ -724,20 +737,22 @@ def query_replies_need_review(self) -> list[Message]:
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
AND m.lang = :lang -- matches lang
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id, m.role
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""

def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
"""Query parents which have childern that need further rankings"""

r = self.db.execute(
text(self._sql_find_incomplete_rankings),
{
"num_required_rankings": self.cfg.num_required_rankings,
"ranking_state": message_tree_state.State.RANKING,
"lang": lang,
},
)
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
Expand All @@ -753,20 +768,22 @@ def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
AND NOT m.deleted -- ignore deleted messages as parents
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
AND m.review_result -- parent node must have positive review
AND m.lang = :lang -- parent matches lang
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"""

def query_extendible_parents(self) -> list[ExtendibleParentRow]:
def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]:
"""Query parent messages that have not reached the maximum number of replies."""

r = self.db.execute(
text(self._sql_find_extendible_parents),
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
"lang": lang,
},
)
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
Expand All @@ -787,14 +804,15 @@ def query_extendible_parents(self) -> list[ExtendibleParentRow]:
HAVING COUNT(m.id) < mts.goal_tree_size
"""

def query_extendible_trees(self) -> list[ActiveTreeSizeRow]:
def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]:
"""Query size of active message trees in growing state."""

r = self.db.execute(
text(self._sql_find_extendible_trees),
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
"lang": lang,
},
)
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
Expand Down Expand Up @@ -894,8 +912,12 @@ def ensure_tree_states(self):
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
self._insert_default_state(id, state=state)

def query_num_active_trees(self) -> int:
query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active)
def query_num_active_trees(self, lang: str) -> int:
query = (
self.db.query(func.count(MessageTreeState.message_tree_id))
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(MessageTreeState.active, Message.lang == lang)
)
return query.scalar()

def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
Expand Down
3 changes: 3 additions & 0 deletions oasst-shared/oasst_shared/schemas/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ConversationMessage(BaseModel):
id: Optional[UUID] = None
frontend_message_id: Optional[str] = None
text: str
lang: Optional[str] # BCP 47
is_assistant: bool


Expand Down Expand Up @@ -72,6 +73,7 @@ class TaskRequest(BaseModel):
# this is optional. https://github.com/pydantic/pydantic/issues/1270
user: Optional[User] = Field(None, nullable=True)
collective: bool = False
lang: Optional[str] # BCP 47


class TaskAck(BaseModel):
Expand Down Expand Up @@ -266,6 +268,7 @@ class TextReplyToMessage(Interaction):
message_id: str
user_message_id: str
text: constr(min_length=1, strip_whitespace=True)
lang: Optional[str] # BCP 47


class MessageRating(Interaction):
Expand Down
1 change: 1 addition & 0 deletions oasst-shared/tests/test_oasst_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
message_id="123",
user_message_id="321",
text="This is my reply",
lang="en",
user=protocol_schema.User(
id="123",
display_name="lomz",
Expand Down