Skip to content

Commit

Permalink
add admin purge user function (#834)
Browse files Browse the repository at this point in the history
* add admin purge user function

* improve comments

* minor naming changes

* ensuer user is enabled for tasks api requests

* add preview with stats to /admin/purge_user/{id} endpoint

* add update_children_counts()
  • Loading branch information
andreaskoepf committed Jan 19, 2023
1 parent 70febaa commit 335af5d
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 12 deletions.
52 changes: 52 additions & 0 deletions backend/oasst_backend/api/v1/admin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from uuid import UUID

import pydantic
from fastapi import APIRouter, Depends
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.models.api_client import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.tree_manager import TreeManager
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.schemas.protocol import SystemStats
from oasst_shared.utils import ScopeTimer

router = APIRouter()

Expand Down Expand Up @@ -29,3 +37,47 @@ async def create_api_client(
)
logger.info(f"Created api_client with key {api_client.api_key}")
return api_client.api_key


class PurgeResultModel(pydantic.BaseModel):
before: SystemStats
after: SystemStats
preview: bool
duration: float


@router.post("/purge_user/{user_id}", response_model=PurgeResultModel)
async def purge_user(
user_id: UUID,
preview: bool = True,
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> str:
assert api_client.trusted

@managed_tx_function(CommitMode.NONE if preview else CommitMode.COMMIT)
def purge_tx(session: deps.Session):
pr = PromptRepository(session, api_client)

stats_before = pr.get_stats()

user = pr.user_repository.get_user(user_id)
tm = TreeManager(session, pr)
tm.purge_user(user_id)

return user, stats_before, pr.get_stats()

timer = ScopeTimer()
user, before, after = purge_tx()
timer.stop()

if preview:
logger.info(
f"PURGE USER PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
)
else:
logger.warning(
f"PURGE USER: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
)

logger.info(f"{before=}; {after=}")
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
4 changes: 4 additions & 0 deletions backend/oasst_backend/api/v1/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def request_task(

try:
pr = PromptRepository(db, api_client, client_user=request.user)
pr.ensure_user_is_enabled()

tm = TreeManager(db, pr)
task, message_tree_id, parent_message_id = tm.next_task(request.type)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
Expand Down Expand Up @@ -85,6 +87,7 @@ def tasks_acknowledge(

try:
pr = PromptRepository(db, api_client)
pr.ensure_user_is_enabled()

# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
Expand Down Expand Up @@ -113,6 +116,7 @@ def tasks_acknowledge_failure(
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client)
pr.ensure_user_is_enabled()
pr.task_repository.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
Expand Down
32 changes: 27 additions & 5 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from sqlmodel import Session, func, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND


Expand All @@ -53,6 +52,13 @@ def __init__(
)
self.journal = JournalWriter(db, api_client, self.user)

def ensure_user_is_enabled(self):
if self.user is None or self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)

if self.user.deleted or not self.user.enabled:
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)

def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
validate_frontend_message_id(frontend_message_id)
message: Message = (
Expand Down Expand Up @@ -146,6 +152,8 @@ def store_text_reply(
review_result: bool = False,
check_tree_state: bool = True,
) -> Message:
self.ensure_user_is_enabled()

validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)

Expand Down Expand Up @@ -354,8 +362,7 @@ def insert_message_embedding(self, message_id: UUID, model: str, embedding: List

@managed_tx_method(CommitMode.FLUSH)
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
self.ensure_user_is_enabled()

container = PayloadContainer(payload=payload)
reaction = MessageReaction(
Expand Down Expand Up @@ -499,7 +506,7 @@ def fetch_random_initial_prompts(self, size: int = 5):
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return messages

def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True):
def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True) -> list[Message]:
qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id)
if reviewed:
qry = qry.filter(Message.review_result)
Expand Down Expand Up @@ -702,6 +709,21 @@ def query_messages(

return messages.all()

def update_children_counts(self, message_tree_id: UUID):
sql_update_children_count = """
UPDATE message SET children_count = cc.children_count
FROM (
SELECT m.id, count(c.id) - COALESCE(SUM(c.deleted::int), 0) AS children_count
FROM message m
LEFT JOIN message c ON m.id = c.parent_id
WHERE m.message_tree_id = :message_tree_id
GROUP BY m.id
) AS cc
WHERE message.id = cc.id;
"""
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")

@managed_tx_method(CommitMode.COMMIT)
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
"""
Expand Down
142 changes: 140 additions & 2 deletions backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlalchemy.sql import text
from sqlmodel import Session, func, not_
from sqlmodel import Session, func, not_, text


class TaskType(Enum):
Expand Down Expand Up @@ -192,6 +191,8 @@ def _determine_task_availability_internal(
return task_count_by_type

def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
self.pr.ensure_user_is_enabled()

num_active_trees = self.query_num_active_trees()
extendible_parents = self.query_extendible_parents()
prompts_need_review = self.query_prompts_need_review()
Expand All @@ -212,6 +213,8 @@ def next_task(

logger.debug("TreeManager.next_task()")

self.pr.ensure_user_is_enabled()

num_active_trees = self.query_num_active_trees()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
Expand Down Expand Up @@ -445,6 +448,7 @@ def next_task(
@async_managed_tx_method(CommitMode.COMMIT)
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
pr = self.pr
pr.ensure_user_is_enabled()
match type(interaction):
case protocol_schema.TextReplyToMessage:
logger.info(
Expand Down Expand Up @@ -978,6 +982,136 @@ def stats(self) -> TreeManagerStats:
message_counts=self.tree_message_count_stats(only_active=True),
)

def get_user_messages_by_tree(self, user_id: UUID) -> Tuple[dict[UUID, list[Message]], list[Message]]:
"""Returns a dict with replies by tree (excluding initial prompts) and list of initial prompts
associated with user_id."""

# query all messages of the user
qry = self.db.query(Message).filter(Message.user_id == user_id)

prompts: list[Message] = []
replies_by_tree: dict[UUID, list[Message]] = {}

# walk over result set and distinguish between initial prompts and replies
for m in qry:
m: Message

if m.message_tree_id == m.id:
prompts.append(m)
else:
message_list = replies_by_tree.get(m.message_tree_id)
if message_list is None:
message_list = [m]
replies_by_tree[m.message_tree_id] = message_list
else:
message_list.append(m)

return replies_by_tree, prompts

def _purge_message_internal(self, message_id: UUID) -> None:
"""This internal function deletes a single message. It does not take care of
descendants, children_count in parent etc."""

sql_purge_message = """
DELETE FROM journal j USING message m WHERE j.message_id = :message_id;
DELETE FROM message_embedding e WHERE e.message_id = :message_id;
DELETE FROM message_toxicity t WHERE t.message_id = :message_id;
DELETE FROM text_labels l WHERE l.message_id = :message_id;
-- delete all ranking results that contain message
DELETE FROM message_reaction r WHERE r.payload_type = 'RankingReactionPayload' AND r.task_id IN (
SELECT t.id FROM message m
JOIN task t ON m.parent_id = t.parent_message_id
WHERE m.id = :message_id);
-- delete task which inserted message
DELETE FROM task t using message m WHERE t.id = m.task_id AND m.id = :message_id;
DELETE FROM task t WHERE t.parent_message_id = :message_id;
DELETE FROM message WHERE id = :message_id;
"""
r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")

def purge_message_tree(self, message_tree_id: UUID) -> None:
sql_purge_message_tree = """
DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
DELETE FROM message_embedding e USING message m WHERE e.message_id = m.Id AND m.message_tree_id = :message_tree_id;
DELETE FROM message_toxicity t USING message m WHERE t.message_id = m.Id AND m.message_tree_id = :message_tree_id;
DELETE FROM text_labels l USING message m WHERE l.message_id = m.Id AND m.message_tree_id = :message_tree_id;
DELETE FROM message_reaction r USING task t WHERE r.task_id = t.id AND t.message_tree_id = :message_tree_id;
DELETE FROM task t WHERE t.message_tree_id = :message_tree_id;
DELETE FROM message_tree_state WHERE message_tree_id = :message_tree_id;
DELETE FROM message WHERE message_tree_id = :message_tree_id;
"""
r = self.db.execute(text(sql_purge_message_tree), {"message_tree_id": message_tree_id})
logger.debug(f"purge_message_tree updated({message_tree_id=}) {r.rowcount} rows.")

@managed_tx_method(CommitMode.FLUSH)
def purge_user_messages(self, user_id: UUID, purge_initial_prompts: bool = True):

# find all affected message trees
replies_by_tree, prompts = self.get_user_messages_by_tree(user_id)

# remove all trees based on inital prompts of the user
if purge_initial_prompts:
for p in prompts:
self.purge_message_tree(p.message_tree_id)
if p.message_tree_id in replies_by_tree:
del replies_by_tree[p.message_tree_id]

# patch all affected message trees
for tree_id, replies in replies_by_tree.items():
bad_parent_ids = set(m.id for m in replies)

tree_messages = self.pr.fetch_message_tree(tree_id)
by_id = {m.id: m for m in tree_messages}

def ancestor_ids(msg: Message) -> list[UUID]:
t = []
while msg.parent_id is not None:
msg = by_id[msg.parent_id]
t.append(msg.id)
return t

def is_descendant_of_deleted(m: Message) -> bool:
if m.id in bad_parent_ids:
return True
ancestors = ancestor_ids(m)
if any(a in bad_parent_ids for a in ancestors):
return True
return False

# start with deepest messages first
tree_messages.sort(key=lambda x: x.depth, reverse=True)
for m in tree_messages:
if is_descendant_of_deleted(m):
self._purge_message_internal(m.id)

# update childern counts
self.pr.update_children_counts(m.message_tree_id)

# reactivate tree
logger.info(f"reactivating message tree {tree_id}")
mts = self.pr.fetch_tree_state(tree_id)
mts.active = True
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
self.check_condition_for_growing_state(tree_id)

@managed_tx_method(CommitMode.FLUSH)
def purge_user(self, user_id: UUID) -> None:
self.purge_user_messages(user_id, purge_initial_prompts=True)

# delete all remaining rows and ban user
sql_purge_user = """
DELETE FROM journal WHERE user_id = :user_id;
DELETE FROM message_reaction WHERE user_id = :user_id;
DELETE FROM task WHERE user_id = :user_id;
DELETE FROM message WHERE user_id = :user_id;
DELETE FROM user_stats WHERE user_id = :user_id;
UPDATE "user" SET deleted = TRUE, enabled = FALSE WHERE id = :user_id;
"""

r = self.db.execute(text(sql_purge_user), {"user_id": user_id})
logger.debug(f"purge_user({user_id=}): {r.rowcount} rows.")


if __name__ == "__main__":
from oasst_backend.api.deps import api_auth
Expand All @@ -994,6 +1128,10 @@ def stats(self) -> TreeManagerStats:
tm = TreeManager(db, pr, cfg)
tm.ensure_tree_states()

# tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
# db.commit()

# print("query_num_active_trees", tm.query_num_active_trees())
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
# print("query_replies_need_review", tm.query_replies_need_review())
Expand Down
8 changes: 6 additions & 2 deletions oasst-shared/oasst_shared/exceptions/oasst_api_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class OasstErrorCode(IntEnum):
RATING_OUT_OF_RANGE = 2002
INVALID_RANKING_VALUE = 2003
INVALID_TASK_TYPE = 2004
USER_NOT_SPECIFIED = 2005

NO_MESSAGE_TREE_FOUND = 2006
NO_REPLIES_FOUND = 2007
INVALID_MESSAGE = 2008
Expand All @@ -62,11 +62,15 @@ class OasstErrorCode(IntEnum):
TASK_NOT_COLLECTIVE = 2106
TASK_NOT_ASSIGNED_TO_USER = 2106
TASK_UNEXPECTED_PAYLOAD_TYPE_ = 2107
USER_NOT_FOUND = 2200

# 3000-4000: external resources
HUGGINGFACE_API_ERROR = 3001

# 4000-5000: user
USER_NOT_SPECIFIED = 4000
USER_DISABLED = 4001
USER_NOT_FOUND = 4002


class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""
Expand Down

0 comments on commit 335af5d

Please sign in to comment.