Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for updating search vectors on edit #3382

Merged
merged 5 commits into from Jun 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,33 @@
"""Add search vector update date, indexes

Revision ID: d18213c629be
Revises: c181661eba3a
Create Date: 2023-06-11 15:44:49.911455

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "d18213c629be"
down_revision = "c181661eba3a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Apply the migration.

    Adds the nullable, timezone-aware ``search_vector_update_date`` column to
    ``message`` and creates two non-unique indexes: one on that new column and
    one on ``message_revision.created_date``.
    """
    # Nullable so that existing rows need no backfill when the column appears.
    update_date_column = sa.Column(
        "search_vector_update_date",
        sa.DateTime(timezone=True),
        nullable=True,
    )
    op.add_column("message", update_date_column)

    op.create_index(
        op.f("ix_message_search_vector_update_date"),
        "message",
        ["search_vector_update_date"],
        unique=False,
    )
    op.create_index(
        op.f("ix_message_revision_created_date"),
        "message_revision",
        ["created_date"],
        unique=False,
    )


def downgrade() -> None:
    """Revert the migration.

    Drops the two indexes created by ``upgrade`` and then removes the
    ``search_vector_update_date`` column from ``message``.
    """
    # Indexes go first, in reverse order of creation; the column is dropped last.
    indexes_to_drop = (
        ("ix_message_revision_created_date", "message_revision"),
        ("ix_message_search_vector_update_date", "message"),
    )
    for index_name, owning_table in indexes_to_drop:
        op.drop_index(op.f(index_name), table_name=owning_table)

    op.drop_column("message", "search_vector_update_date")
3 changes: 3 additions & 0 deletions backend/oasst_backend/models/message.py
Expand Up @@ -56,6 +56,9 @@ def __new__(cls, *args: Any, **kwargs: Any):
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))

search_vector: Optional[str] = Field(sa_column=sa.Column(pg.TSVECTOR(), nullable=True))
search_vector_update_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, index=True)
)

review_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
review_result: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=True))
Expand Down
4 changes: 3 additions & 1 deletion backend/oasst_backend/models/message_revision.py
Expand Up @@ -22,7 +22,9 @@ class MessageRevision(SQLModel, table=True):
message_id: UUID = Field(sa_column=sa.Column(sa.ForeignKey("message.id"), nullable=False, index=True))
user_id: Optional[UUID] = Field(sa_column=sa.Column(sa.ForeignKey("user.id"), nullable=True))
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
sa_column=sa.Column(
sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp(), index=True
)
)

_user_is_author: Optional[bool] = PrivateAttr(default=None)
41 changes: 37 additions & 4 deletions backend/oasst_backend/scheduled_tasks.py
Expand Up @@ -7,14 +7,14 @@
from celery import shared_task
from loguru import logger
from oasst_backend.celery_worker import app
from oasst_backend.models import ApiClient, Message, User
from oasst_backend.models import ApiClient, Message, MessageRevision, User
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import db_lang_to_postgres_ts_lang, default_session_factory
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.utils import log_timing, utcnow
from sqlalchemy import func
from sqlmodel import update
from sqlmodel import and_, or_, update


async def useHFApi(text, url, model_name):
Expand Down Expand Up @@ -72,17 +72,50 @@ def update_search_vectors(batch_size: int) -> None:
try:
with default_session_factory() as session:
while True:
to_update: list[Message] = (
session.query(Message).filter(Message.search_vector.is_(None)).limit(batch_size).all()
query = session.query(Message)

# Subquery to obtain creation date of most recent revision for a message
latest_revision_date_subquery = (
session.query(func.max(MessageRevision.created_date))
.filter(MessageRevision.message_id == Message.id)
.correlate(Message)
.as_scalar()
)

# Outerjoin messages to their most recent revisions
query = query.outerjoin(
MessageRevision,
and_(
Message.id == MessageRevision.message_id,
MessageRevision.created_date == latest_revision_date_subquery,
),
)

# Filter for only messages where we want to update the search vector
# The core components are when search vector is null, or there is a revision since last vector update
# We also add the case where there is a revision and no vector update date
# This accounts for messages where the vector was generated before vector update dates were added
query = query.filter(
or_(
Message.search_vector.is_(None),
MessageRevision.created_date > Message.search_vector_update_date,
and_(
Message.search_vector_update_date.is_(None),
MessageRevision.created_date.isnot(None),
),
)
)

to_update: list[Message] = query.limit(batch_size).all()
melvinebenezer marked this conversation as resolved.
Show resolved Hide resolved

if not to_update:
break

for message in to_update:
message_payload: MessagePayload = message.payload.payload
message_lang: str = db_lang_to_postgres_ts_lang(message.lang)
message.search_vector = func.to_tsvector(message_lang, message_payload.text)
message.search_vector_update_date = utcnow()

session.commit()
except Exception as e:
Expand Down