From 3e1047f6b9eb403be4f9038d70aaef8f0958293f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 16 Feb 2023 18:42:23 +0100 Subject: [PATCH] 1564 mitigate reply exploit (#1631) * limit number of pending tasks per user (within recent_tasks_span) * add trace warning * chane function name to fetch_pending_tasks_of_user() --- backend/oasst_backend/config.py | 4 +++ backend/oasst_backend/task_repository.py | 30 +++++++++++++++++-- backend/oasst_backend/tree_manager.py | 22 ++++++++++++++ .../exceptions/oasst_api_error.py | 1 + 4 files changed, 54 insertions(+), 3 deletions(-) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 205d52b80b..0f43284d8b 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -141,6 +141,10 @@ class TreeManagerConfiguration(BaseModel): recent_tasks_span_sec: int = 5 * 60 # 5 min """Time in seconds of recent tasks to consider for exclusion during task selection.""" + max_pending_tasks_per_user: int = 8 + """Maximum number of pending tasks (neither canceled nor completed) by a single user within + the time span defined by `recent_tasks_span_sec`.""" + class Settings(BaseSettings): PROJECT_NAME: str = "open-assistant backend" diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index 2df9efa45a..da476f437a 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -12,7 +12,7 @@ from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.utils import utcnow -from sqlmodel import Session, delete, false, func, or_ +from sqlmodel import Session, delete, false, func, not_, or_ from starlette.status import HTTP_404_NOT_FOUND @@ -222,10 +222,14 @@ def fetch_task_by_id(self, task_id: UUID) -> Task: return task def fetch_recent_reply_tasks( - self, max_age: timedelta = timedelta(minutes=5), done: bool = False, skipped: bool = False, limit: int = 100 + self, + max_age: timedelta = timedelta(minutes=5), + done: bool = False, + skipped: bool = False, + limit: int = 100, ) -> list[Task]: qry = self.db.query(Task).filter( - func.age(func.current_timestamp(), Task.created_date) < max_age, + Task.created_date > func.current_timestamp() - max_age, or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"), ) if done is not None: @@ -238,3 +242,23 @@ def fetch_recent_reply_tasks( def delete_expired(self) -> int: return delete_expired_tasks(self.db) + + def fetch_pending_tasks_of_user( + self, + user_id: UUID, + max_age: timedelta = timedelta(minutes=5), + limit: int = 100, + ) -> list[Task]: + qry = ( + self.db.query(Task) + .filter( + Task.user_id == user_id, + Task.created_date > func.current_timestamp() - max_age, + not_(Task.done), + not_(Task.skipped), + ) + .order_by(Task.created_date) + ) + if limit: + qry = qry.limit(limit) + return qry.all() diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index a31fcbdd91..ae36f9296b 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -378,6 +378,28 @@ def next_task( self._auto_moderation(lang=lang) num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2) + # check user's pending tasks + recent_tasks_span = timedelta(seconds=self.cfg.recent_tasks_span_sec) + users_pending_tasks = self.pr.task_repository.fetch_pending_tasks_of_user( + self.pr.user_id, + max_age=recent_tasks_span, + limit=self.cfg.max_pending_tasks_per_user + 1, + ) + num_pending_tasks = len(users_pending_tasks) + if num_pending_tasks >= self.cfg.max_pending_tasks_per_user: + logger.warning( + f"Rejecting task request. User {self.pr.user_id} has {num_pending_tasks} pending tasks. " + f"Oldest age: {utcnow()-users_pending_tasks[0].created_date}." + ) + raise OasstError( + "User has too many pending tasks.", + OasstErrorCode.TASK_TOO_MANY_PENDING, + ) + elif num_pending_tasks > 0: + logger.debug( + f"User {self.pr.user_id} has {num_pending_tasks} pending tasks. Oldest age: {utcnow()-users_pending_tasks[0].created_date}" + ) + prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang) diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index a13d06192f..29a38c7e48 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -43,6 +43,7 @@ class OasstErrorCode(IntEnum): TASK_MESSAGE_DUPLICATED = 1009 TASK_MESSAGE_TEXT_EMPTY = 1010 TASK_MESSAGE_DUPLICATE_REPLY = 1011 + TASK_TOO_MANY_PENDING = 1012 # 2000-3000: prompt_repository INVALID_FRONTEND_MESSAGE_ID = 2000