Skip to content

Commit

Permalink
Draft fix for updating search vectors on edit
Browse files Browse the repository at this point in the history
  • Loading branch information
olliestanley committed Jun 11, 2023
1 parent 9070e31 commit 85d3ada
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
3 changes: 3 additions & 0 deletions backend/oasst_backend/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __new__(cls, *args: Any, **kwargs: Any):
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))

search_vector: Optional[str] = Field(sa_column=sa.Column(pg.TSVECTOR(), nullable=True))
search_vector_update_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, index=True)
)

review_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
review_result: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=True))
Expand Down
4 changes: 3 additions & 1 deletion backend/oasst_backend/models/message_revision.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class MessageRevision(SQLModel, table=True):
message_id: UUID = Field(sa_column=sa.Column(sa.ForeignKey("message.id"), nullable=False, index=True))
user_id: Optional[UUID] = Field(sa_column=sa.Column(sa.ForeignKey("user.id"), nullable=True))
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
sa_column=sa.Column(
sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp(), index=True
)
)

_user_is_author: Optional[bool] = PrivateAttr(default=None)
29 changes: 26 additions & 3 deletions backend/oasst_backend/scheduled_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from celery import shared_task
from loguru import logger
from oasst_backend.celery_worker import app
from oasst_backend.models import ApiClient, Message, User
from oasst_backend.models import ApiClient, Message, MessageRevision, User
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import db_lang_to_postgres_ts_lang, default_session_factory
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.utils import log_timing, utcnow
from sqlalchemy import func
from sqlmodel import update
from sqlmodel import and_, or_, update


async def useHFApi(text, url, model_name):
Expand Down Expand Up @@ -73,7 +73,29 @@ def update_search_vectors(batch_size: int) -> None:
with default_session_factory() as session:
while True:
to_update: list[Message] = (
session.query(Message).filter(Message.search_vector.is_(None)).limit(batch_size).all()
session.query(Message)
.outerjoin(
MessageRevision,
and_(
Message.id == MessageRevision.message_id,
MessageRevision.created_date
== session.query(func.max(MessageRevision.created_date))
.filter(MessageRevision.message_id == Message.id)
.as_scalar(),
),
)
.filter(
or_(
Message.search_vector.is_(None),
MessageRevision.created_date > Message.search_vector_update_date,
and_(
Message.search_vector_update_date.is_(None),
MessageRevision.created_date.isnot(None),
),
)
)
.limit(batch_size)
.all()
)

if not to_update:
Expand All @@ -83,6 +105,7 @@ def update_search_vectors(batch_size: int) -> None:
message_payload: MessagePayload = message.payload.payload
message_lang: str = db_lang_to_postgres_ts_lang(message.lang)
message.search_vector = func.to_tsvector(message_lang, message_payload.text)
message.search_vector_update_date = utcnow()

session.commit()
except Exception as e:
Expand Down

0 comments on commit 85d3ada

Please sign in to comment.