diff --git a/backend/main.py b/backend/main.py index 5e86e45289..b309e38e6c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -21,7 +21,7 @@ from oasst_backend.models import message_tree_state from oasst_backend.prompt_repository import PromptRepository, UserRepository from oasst_backend.task_repository import TaskRepository, delete_expired_tasks -from oasst_backend.tree_manager import TreeManager +from oasst_backend.tree_manager import TreeManager, halt_prompts_of_disabled_users from oasst_backend.user_repository import User from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame from oasst_backend.utils.database_utils import CommitMode, managed_tx_function @@ -333,6 +333,7 @@ def update_user_streak(session: Session) -> None: @managed_tx_function(auto_commit=CommitMode.COMMIT) def cronjob_delete_expired_tasks(session: Session) -> None: delete_expired_tasks(session) + halt_prompts_of_disabled_users(session) @app.on_event("startup") diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 7dea2782d3..19ad972772 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -120,6 +120,33 @@ class TreeManagerStats(pydantic.BaseModel): message_counts: list[TreeMessageCountStats] +def halt_prompts_of_disabled_users(db: Session): + _sql_halt_prompts_of_disabled_users = """ +-- remove prompts of disabled & deleted users from prompt lottery +WITH cte AS ( +SELECT mts.message_tree_id +FROM message_tree_state mts +JOIN message m ON mts.message_tree_id = m.id +JOIN "user" u ON m.user_id = u.id +WHERE state = :prompt_lottery_waiting_state AND (NOT u.enabled OR u.deleted) +) +UPDATE message_tree_state mts2 +SET active=false, state=:halted_by_moderator_state +FROM cte +WHERE mts2.message_tree_id = cte.message_tree_id; +""" + + r = db.execute( + text(_sql_halt_prompts_of_disabled_users), + { + "prompt_lottery_waiting_state": message_tree_state.State.PROMPT_LOTTERY_WAITING, + "halted_by_moderator_state": message_tree_state.State.HALTED_BY_MODERATOR, + }, + ) + if r.rowcount > 0: + logger.info(f"Halted {r.rowcount} prompts of disabled users.") + + class TreeManager: def __init__( self, @@ -240,16 +267,20 @@ def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int: @managed_tx_function(CommitMode.COMMIT) def activate_one(db: Session) -> int: + # select among distinct users authors_qry = ( db.query(Message.user_id) .select_from(MessageTreeState) .join(Message, MessageTreeState.message_tree_id == Message.id) + .join(User, Message.user_id == User.id) .filter( MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING, Message.lang == lang, not_(Message.deleted), Message.review_result, + User.enabled, + not_(User.deleted), ) .distinct(Message.user_id) ) @@ -1309,6 +1340,8 @@ def ensure_tree_states(self) -> None: logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})") self._insert_default_state(id, state=state) + halt_prompts_of_disabled_users(self.db) + # check tree state transitions (maybe variables haves changes): prompt review -> growing -> ranking -> scoring prompt_review_trees: list[MessageTreeState] = ( self.db.query(MessageTreeState)