Skip to content

Commit

Permalink
1564 mitigate reply exploit (#1631)
Browse files Browse the repository at this point in the history
* limit number of pending tasks per user (within recent_tasks_span)

* add trace warning

* change function name to fetch_pending_tasks_of_user()
  • Loading branch information
andreaskoepf committed Feb 16, 2023
1 parent 44fe6e6 commit 3e1047f
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 3 deletions.
4 changes: 4 additions & 0 deletions backend/oasst_backend/config.py
Expand Up @@ -141,6 +141,10 @@ class TreeManagerConfiguration(BaseModel):
recent_tasks_span_sec: int = 5 * 60 # 5 min
"""Time in seconds of recent tasks to consider for exclusion during task selection."""

max_pending_tasks_per_user: int = 8
"""Maximum number of pending tasks (neither canceled nor completed) by a single user within
the time span defined by `recent_tasks_span_sec`."""


class Settings(BaseSettings):
PROJECT_NAME: str = "open-assistant backend"
Expand Down
30 changes: 27 additions & 3 deletions backend/oasst_backend/task_repository.py
Expand Up @@ -12,7 +12,7 @@
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.utils import utcnow
from sqlmodel import Session, delete, false, func, or_
from sqlmodel import Session, delete, false, func, not_, or_
from starlette.status import HTTP_404_NOT_FOUND


Expand Down Expand Up @@ -222,10 +222,14 @@ def fetch_task_by_id(self, task_id: UUID) -> Task:
return task

def fetch_recent_reply_tasks(
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, skipped: bool = False, limit: int = 100
self,
max_age: timedelta = timedelta(minutes=5),
done: bool = False,
skipped: bool = False,
limit: int = 100,
) -> list[Task]:
qry = self.db.query(Task).filter(
func.age(func.current_timestamp(), Task.created_date) < max_age,
Task.created_date > func.current_timestamp() - max_age,
or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
)
if done is not None:
Expand All @@ -238,3 +242,23 @@ def fetch_recent_reply_tasks(

def delete_expired(self) -> int:
return delete_expired_tasks(self.db)

def fetch_pending_tasks_of_user(
    self,
    user_id: UUID,
    max_age: timedelta = timedelta(minutes=5),
    limit: int = 100,
) -> list[Task]:
    """Return a user's recent tasks that are neither done nor skipped.

    Args:
        user_id: Only tasks whose `user_id` equals this value are returned.
        max_age: Only tasks whose `created_date` is later than the database's
            current timestamp minus this span are returned.
        limit: Maximum number of rows to fetch. A falsy value (`0` or `None`)
            leaves the query without a LIMIT clause.

    Returns:
        The matching tasks ordered by `created_date` ascending, so the first
        element is the oldest one.
    """
    # Cutoff is computed by the database, not from the application's clock.
    created_after = func.current_timestamp() - max_age

    pending = self.db.query(Task).filter(
        Task.user_id == user_id,
        Task.created_date > created_after,
        not_(Task.done),
        not_(Task.skipped),
    )
    pending = pending.order_by(Task.created_date)

    if not limit:
        return pending.all()
    return pending.limit(limit).all()
22 changes: 22 additions & 0 deletions backend/oasst_backend/tree_manager.py
Expand Up @@ -378,6 +378,28 @@ def next_task(
self._auto_moderation(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2)

# check user's pending tasks
recent_tasks_span = timedelta(seconds=self.cfg.recent_tasks_span_sec)
users_pending_tasks = self.pr.task_repository.fetch_pending_tasks_of_user(
self.pr.user_id,
max_age=recent_tasks_span,
limit=self.cfg.max_pending_tasks_per_user + 1,
)
num_pending_tasks = len(users_pending_tasks)
if num_pending_tasks >= self.cfg.max_pending_tasks_per_user:
logger.warning(
f"Rejecting task request. User {self.pr.user_id} has {num_pending_tasks} pending tasks. "
f"Oldest age: {utcnow()-users_pending_tasks[0].created_date}."
)
raise OasstError(
"User has too many pending tasks.",
OasstErrorCode.TASK_TOO_MANY_PENDING,
)
elif num_pending_tasks > 0:
logger.debug(
f"User {self.pr.user_id} has {num_pending_tasks} pending tasks. Oldest age: {utcnow()-users_pending_tasks[0].created_date}"
)

prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang)
Expand Down
1 change: 1 addition & 0 deletions oasst-shared/oasst_shared/exceptions/oasst_api_error.py
Expand Up @@ -43,6 +43,7 @@ class OasstErrorCode(IntEnum):
TASK_MESSAGE_DUPLICATED = 1009
TASK_MESSAGE_TEXT_EMPTY = 1010
TASK_MESSAGE_DUPLICATE_REPLY = 1011
TASK_TOO_MANY_PENDING = 1012

# 2000-3000: prompt_repository
INVALID_FRONTEND_MESSAGE_ID = 2000
Expand Down

0 comments on commit 3e1047f

Please sign in to comment.