Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export average label values per message #1570

Merged
merged 5 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add text_labels message_id index

Revision ID: 165b55de5a94
Revises: ba40d055714a
Create Date: 2023-02-14 17:56:48.263684

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "165b55de5a94"
down_revision = "ba40d055714a"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(op.f("ix_text_labels_message_id"), "text_labels", ["message_id"], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_text_labels_message_id"), table_name="text_labels")
# ### end Alembic commands ###
152 changes: 107 additions & 45 deletions backend/export.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import argparse
from pathlib import Path
from typing import Dict, List, Optional
from typing import List, Optional
from uuid import UUID

import sqlalchemy as sa
from loguru import logger
from oasst_backend.database import engine
from oasst_backend.models import Message, MessageTreeState, TextLabels
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_backend.utils import tree_export
from sqlmodel import Session, not_
from oasst_shared.schemas.protocol import TextLabel
from sqlmodel import Session, func, not_


def fetch_tree_ids(
Expand Down Expand Up @@ -60,17 +62,41 @@ def fetch_tree_messages(
return qry.all()


def fetch_message_labels(db: Session, messages: List[Message]) -> Dict[UUID, List[TextLabels]]:
message_ids = [message.id for message in messages]
def fetch_tree_messages_and_avg_labels(
db: Session,
message_tree_id: Optional[UUID] = None,
user_id: Optional[UUID] = None,
deleted: bool = None,
prompts_only: bool = False,
lang: Optional[str] = None,
review_result: Optional[bool] = None,
) -> List[Message]:

qry = db.query(TextLabels).filter(TextLabels.message_id.in_(message_ids))
args = [Message]

for l in TextLabel:
args.append(func.avg(TextLabels.labels[l].cast(sa.Float)).label(l.value))
args.append(func.count(TextLabels.labels[l]).label(l.value + "_count"))

qry = db.query(*args).select_from(Message).join(TextLabels, Message.id == TextLabels.message_id)
if message_tree_id:
qry = qry.filter(Message.message_tree_id == message_tree_id)
if user_id:
qry = qry.filter(Message.user_id == user_id)
if deleted is not None:
qry = qry.filter(Message.deleted == deleted)
if prompts_only:
qry = qry.filter(Message.parent_id.is_(None))
if lang:
qry = qry.filter(Message.lang == lang)
if review_result is False:
qry = qry.filter(not_(Message.review_result), Message.review_count > 2)
elif review_result is True:
qry = qry.filter(Message.review_result)

all_labels: List[TextLabels] = qry.all()
qry = qry.group_by(Message.id)

return {
message.id: [message_labels for message_labels in all_labels if message_labels.message_id == message.id]
for message in messages
}
return qry.all()


def export_trees(
Expand All @@ -85,10 +111,10 @@ def export_trees(
review_result: Optional[bool] = None,
export_labels: bool = False,
) -> None:
trees_to_export: List[tree_export.ExportMessageTree] = []

if user_id or review_result is False:
messages = fetch_tree_messages(
message_labels: dict[UUID, tree_export.LabelValues] = {}
if user_id:
# when filtering by user we don't have complete message trees, export as list
result = fetch_tree_messages_and_avg_labels(
db,
user_id=user_id,
deleted=deleted,
Expand All @@ -97,43 +123,79 @@ def export_trees(
review_result=review_result,
)

labels = fetch_message_labels(db, messages) if export_labels else None

tree_export.write_messages_to_file(export_file, messages, use_compression, labels=labels)
messages: list[Message] = []
for r in result:
msg = r["Message"]
messages.append(msg)
if export_labels:
labels: tree_export.LabelValues = {
l.value: tree_export.LabelAvgValue(value=r[l.value], count=r[l.value + "_count"])
for l in TextLabel
if r[l.value] is not None
}
message_labels[msg.id] = labels

tree_export.write_messages_to_file(export_file, messages, use_compression, labels=message_labels)
else:
message_tree_ids = fetch_tree_ids(db, state_filter, lang=lang)

message_trees = [
fetch_tree_messages(
db,
message_tree_id=tree_id,
deleted=deleted,
prompts_only=prompts_only,
lang=None,
review_result=review_result,
)
for (tree_id, _) in message_tree_ids
]

# when exporting only deleted we don't have a proper tree structure, export as list
if deleted is True:
messages = [m for t in message_trees for m in t]

labels = fetch_message_labels(db, messages) if export_labels else None

tree_export.write_messages_to_file(export_file, messages, use_compression, labels=labels)
else:
message_trees: list[list[Message]] = []

for tree_id, _ in message_tree_ids:
if export_labels:
all_messages = [m for t in message_trees for m in t]
labels = fetch_message_labels(db, all_messages)
result = fetch_tree_messages_and_avg_labels(
db,
message_tree_id=tree_id,
deleted=deleted,
prompts_only=prompts_only,
lang=None, # pass None here, trees were selected based on lang of prompt
review_result=review_result,
)

messages: list[Message] = []
for r in result:
msg = r["Message"]
messages.append(msg)
labels: tree_export.LabelValues = {
l.value: tree_export.LabelAvgValue(value=r[l.value], count=r[l.value + "_count"])
for l in TextLabel
if r[l.value] is not None
}
message_labels[msg.id] = labels

message_trees.append(messages)
else:
labels = None
messages = fetch_tree_messages(
db,
message_tree_id=tree_id,
deleted=deleted,
prompts_only=prompts_only,
lang=None, # pass None here, trees were selected based on lang of prompt
review_result=review_result,
)
message_trees.append(messages)

if review_result is False or deleted is True:
# when exporting filtered we don't have complete message trees, export as list
messages = [m for t in message_trees for m in t] # flatten message list
tree_export.write_messages_to_file(export_file, messages, use_compression, labels=message_labels)
else:
trees_to_export: List[tree_export.ExportMessageTree] = []

for (message_tree_id, message_tree_state), message_tree in zip(message_tree_ids, message_trees):
t = tree_export.build_export_tree(message_tree_id, message_tree_state, message_tree, labels=labels)
if prompts_only:
t.prompt.replies = None
trees_to_export.append(t)
if len(message_tree) > 0:
try:
t = tree_export.build_export_tree(
message_tree_id=message_tree_id,
message_tree_state=message_tree_state,
messages=message_tree,
labels=message_labels,
)
if prompts_only:
t.prompt.replies = None
trees_to_export.append(t)
except Exception as ex:
logger.warning(f"Corrupted tree: {message_tree_id} ({ex})")

tree_export.write_trees_to_file(export_file, trees_to_export, use_compression)

Expand Down Expand Up @@ -202,7 +264,7 @@ def parse_args():
parser.add_argument(
"--export-labels",
action="store_true",
help="Export labels for messages to a separate file",
help="Include average label values for messages",
)

args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion backend/oasst_backend/models/text_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TextLabels(SQLModel, table=True):
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
message_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True)
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True, index=True)
)
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
task_id: Optional[UUID] = Field(nullable=True, index=True)
41 changes: 17 additions & 24 deletions backend/oasst_backend/utils/tree_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@
from uuid import UUID

from fastapi.encoders import jsonable_encoder
from oasst_backend.models import Message, TextLabels
from oasst_backend.models import Message
from oasst_backend.models.message_tree_state import State as TreeState
from pydantic import BaseModel


class LabelAvgValue(BaseModel):
value: float | None
count: int


LabelValues = dict[str, LabelAvgValue]


class ExportMessageNode(BaseModel):
message_id: str
parent_id: str | None
Expand All @@ -27,10 +35,10 @@ class ExportMessageNode(BaseModel):
model_name: str | None
emojis: dict[str, int] | None
replies: list[ExportMessageNode] | None
labels: list[dict[str, float]] | None
labels: LabelValues | None

@staticmethod
def prep_message_export(message: Message) -> ExportMessageNode:
def prep_message_export(message: Message, labels: Optional[LabelValues] = None) -> ExportMessageNode:
return ExportMessageNode(
message_id=str(message.id),
parent_id=str(message.parent_id) if message.parent_id else None,
Expand All @@ -43,14 +51,9 @@ def prep_message_export(message: Message) -> ExportMessageNode:
model_name=message.model_name,
emojis=message.emojis,
rank=message.rank,
labels=labels,
)

@staticmethod
def prep_labelled_message_export(message: Message, labels: list[TextLabels]) -> ExportMessageNode:
node = ExportMessageNode.prep_message_export(message)
node.labels = [label.labels for label in labels]
return node


class ExportMessageTree(BaseModel):
message_tree_id: str
Expand All @@ -62,17 +65,9 @@ def build_export_tree(
message_tree_id: UUID,
message_tree_state: TreeState,
messages: list[Message],
labels: Optional[dict[UUID, list[TextLabels]]] = None,
labels: Optional[dict[UUID, LabelValues]] = None,
) -> ExportMessageTree:
if labels:
export_messages = [
ExportMessageNode.prep_labelled_message_export(m, labels[m.id])
if m.id in labels
else ExportMessageNode.prep_message_export(m)
for m in messages
]
else:
export_messages = [ExportMessageNode.prep_message_export(m) for m in messages]
export_messages = [ExportMessageNode.prep_message_export(m, labels.get(m.id) if labels else None) for m in messages]

messages_by_parent = defaultdict(list)
for message in export_messages:
Expand Down Expand Up @@ -125,7 +120,7 @@ def write_messages_to_file(
filename: str | None,
messages: Iterable[Message],
use_compression: bool = True,
labels: Optional[dict[UUID, list[TextLabels]]] = None,
labels: Optional[dict[UUID, LabelValues]] = None,
) -> None:
out_buff: TextIO

Expand All @@ -138,10 +133,8 @@ def write_messages_to_file(

with out_buff as f:
for m in messages:
if labels and m.id in labels:
export_message = ExportMessageNode.prep_labelled_message_export(m, labels[m.id])
else:
export_message = ExportMessageNode.prep_message_export(m)
export_message = ExportMessageNode.prep_message_export(m, labels.get(m.id) if labels else None)

file_data = jsonable_encoder(export_message, exclude_none=True)
json.dump(file_data, f)
f.write("\n")