Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1564 mitigate reply exploit #1631

Merged
merged 4 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/oasst_backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ class TreeManagerConfiguration(BaseModel):
recent_tasks_span_sec: int = 5 * 60 # 5 min
"""Time in seconds of recent tasks to consider for exclusion during task selection."""

max_pending_tasks_per_user: int = 8
"""Maximum number of pending tasks (neither canceled nor completed) by a single user within
the time span defined by `recent_tasks_span_sec`."""


class Settings(BaseSettings):
PROJECT_NAME: str = "open-assistant backend"
Expand Down
30 changes: 27 additions & 3 deletions backend/oasst_backend/task_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.utils import utcnow
from sqlmodel import Session, delete, false, func, or_
from sqlmodel import Session, delete, false, func, not_, or_
from starlette.status import HTTP_404_NOT_FOUND


Expand Down Expand Up @@ -222,10 +222,14 @@ def fetch_task_by_id(self, task_id: UUID) -> Task:
return task

def fetch_recent_reply_tasks(
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, skipped: bool = False, limit: int = 100
self,
max_age: timedelta = timedelta(minutes=5),
done: bool = False,
skipped: bool = False,
limit: int = 100,
) -> list[Task]:
qry = self.db.query(Task).filter(
func.age(func.current_timestamp(), Task.created_date) < max_age,
Task.created_date > func.current_timestamp() - max_age,
or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
)
if done is not None:
Expand All @@ -238,3 +242,23 @@ def fetch_recent_reply_tasks(

def delete_expired(self) -> int:
return delete_expired_tasks(self.db)

def fetch_pending_tasks_of_user(
    self,
    user_id: UUID,
    max_age: timedelta = timedelta(minutes=5),
    limit: int = 100,
) -> list[Task]:
    """Return a user's recent tasks that are neither done nor skipped.

    Args:
        user_id: Only tasks whose ``user_id`` equals this value are returned.
        max_age: Only tasks whose ``created_date`` is later than the
            database's current timestamp minus this span are returned.
        limit: Maximum number of rows to fetch. A falsy value (``0`` or
            ``None``) leaves the query without a row limit.

    Returns:
        The matching tasks, ordered by ``created_date``.
    """
    # The cutoff is computed from the database clock, not the application clock.
    cutoff = func.current_timestamp() - max_age
    pending = self.db.query(Task).filter(
        Task.user_id == user_id,
        Task.created_date > cutoff,
        not_(Task.done),
        not_(Task.skipped),
    )
    by_creation = pending.order_by(Task.created_date)
    if not limit:
        return by_creation.all()
    return by_creation.limit(limit).all()
22 changes: 22 additions & 0 deletions backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,28 @@ def next_task(
self._auto_moderation(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2)

# check user's pending tasks
recent_tasks_span = timedelta(seconds=self.cfg.recent_tasks_span_sec)
users_pending_tasks = self.pr.task_repository.fetch_pending_tasks_of_user(
self.pr.user_id,
max_age=recent_tasks_span,
limit=self.cfg.max_pending_tasks_per_user + 1,
)
num_pending_tasks = len(users_pending_tasks)
if num_pending_tasks >= self.cfg.max_pending_tasks_per_user:
logger.warning(
f"Rejecting task request. User {self.pr.user_id} has {num_pending_tasks} pending tasks. "
f"Oldest age: {utcnow()-users_pending_tasks[0].created_date}."
)
raise OasstError(
"User has too many pending tasks.",
OasstErrorCode.TASK_TOO_MANY_PENDING,
)
elif num_pending_tasks > 0:
logger.debug(
f"User {self.pr.user_id} has {num_pending_tasks} pending tasks. Oldest age: {utcnow()-users_pending_tasks[0].created_date}"
)

prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang)
Expand Down
1 change: 1 addition & 0 deletions oasst-shared/oasst_shared/exceptions/oasst_api_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class OasstErrorCode(IntEnum):
TASK_MESSAGE_DUPLICATED = 1009
TASK_MESSAGE_TEXT_EMPTY = 1010
TASK_MESSAGE_DUPLICATE_REPLY = 1011
TASK_TOO_MANY_PENDING = 1012

# 2000-3000: prompt_repository
INVALID_FRONTEND_MESSAGE_ID = 2000
Expand Down