Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add admin purge user function #834

Merged
merged 6 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 27 additions & 0 deletions backend/oasst_backend/api/v1/admin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from uuid import UUID

import pydantic
from fastapi import APIRouter, Depends
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.models.api_client import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.tree_manager import TreeManager
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from starlette.status import HTTP_204_NO_CONTENT

router = APIRouter()

Expand Down Expand Up @@ -29,3 +36,23 @@ async def create_api_client(
)
logger.info(f"Created api_client with key {api_client.api_key}")
return api_client.api_key


@router.post("/purge_user/{user_id}", response_model=None, status_code=HTTP_204_NO_CONTENT)
async def purge_user(
    user_id: UUID,
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> None:
    """Purge a user on behalf of a trusted API client.

    The work is delegated to `TreeManager.purge_user`, executed inside a single
    transaction that is committed on success. The endpoint has no response body
    (HTTP 204), hence the `None` return annotation.

    Args:
        user_id: Id of the user to purge.
        api_client: Calling API client, resolved via `deps.get_trusted_api_client`.
    """
    # Defensive invariant check on top of the dependency.
    # NOTE(review): presumably `get_trusted_api_client` already rejects untrusted
    # clients - confirm, since `assert` is stripped when running with `-O`.
    assert api_client.trusted

    @managed_tx_function(CommitMode.COMMIT)
    def purge_tx(session: deps.Session):
        pr = PromptRepository(session, api_client)
        user = pr.user_repository.get_user(user_id)
        # Log who is being purged before the purge itself runs.
        logger.warning(
            f"PURGE USER: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
        )
        tm = TreeManager(session, pr)
        tm.purge_user(user_id)

    purge_tx()
4 changes: 4 additions & 0 deletions backend/oasst_backend/api/v1/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def request_task(

try:
pr = PromptRepository(db, api_client, client_user=request.user)
pr.ensure_user_is_enabled()

tm = TreeManager(db, pr)
task, message_tree_id, parent_message_id = tm.next_task(request.type)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
Expand Down Expand Up @@ -85,6 +87,7 @@ def tasks_acknowledge(

try:
pr = PromptRepository(db, api_client)
pr.ensure_user_is_enabled()

# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
Expand Down Expand Up @@ -113,6 +116,7 @@ def tasks_acknowledge_failure(
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client)
pr.ensure_user_is_enabled()
pr.task_repository.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
Expand Down
14 changes: 11 additions & 3 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def __init__(
)
self.journal = JournalWriter(db, api_client, self.user)

def ensure_user_is_enabled(self):
    """Verify that this repository is bound to a usable user account.

    Raises:
        OasstError: with `USER_NOT_SPECIFIED` when no user (or user id) is set,
            or with `USER_DISABLED` when the user is flagged as deleted or is
            not enabled.
    """
    no_user = self.user is None or self.user_id is None
    if no_user:
        raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)

    # An account is usable only if it is neither deleted nor disabled.
    if not self.user.deleted and self.user.enabled:
        return

    raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)

def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
validate_frontend_message_id(frontend_message_id)
message: Message = (
Expand Down Expand Up @@ -146,6 +153,8 @@ def store_text_reply(
review_result: bool = False,
check_tree_state: bool = True,
) -> Message:
self.ensure_user_is_enabled()

validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)

Expand Down Expand Up @@ -354,8 +363,7 @@ def insert_message_embedding(self, message_id: UUID, model: str, embedding: List

@managed_tx_method(CommitMode.FLUSH)
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
self.ensure_user_is_enabled()

container = PayloadContainer(payload=payload)
reaction = MessageReaction(
Expand Down Expand Up @@ -499,7 +507,7 @@ def fetch_random_initial_prompts(self, size: int = 5):
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return messages

def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True):
def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True) -> list[Message]:
qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id)
if reviewed:
qry = qry.filter(Message.review_result)
Expand Down
146 changes: 146 additions & 0 deletions backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def _determine_task_availability_internal(
return task_count_by_type

def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
self.pr.ensure_user_is_enabled()

num_active_trees = self.query_num_active_trees()
extendible_parents = self.query_extendible_parents()
prompts_need_review = self.query_prompts_need_review()
Expand All @@ -212,6 +214,8 @@ def next_task(

logger.debug("TreeManager.next_task()")

self.pr.ensure_user_is_enabled()

num_active_trees = self.query_num_active_trees()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
Expand Down Expand Up @@ -445,6 +449,7 @@ def next_task(
@async_managed_tx_method(CommitMode.COMMIT)
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
pr = self.pr
pr.ensure_user_is_enabled()
match type(interaction):
case protocol_schema.TextReplyToMessage:
logger.info(
Expand Down Expand Up @@ -978,6 +983,144 @@ def stats(self) -> TreeManagerStats:
message_counts=self.tree_message_count_stats(only_active=True),
)

def get_user_messages_by_tree(self, user_id: UUID) -> Tuple[dict[UUID, list[Message]], list[Message]]:
    """Split the messages written by `user_id` into replies and initial prompts.

    A message whose `message_tree_id` equals its own `id` is treated as an
    initial prompt; every other message is a reply and is grouped under the id
    of the tree it belongs to.

    Returns:
        A tuple `(replies_by_tree, prompts)`: a dict mapping tree id to the
        user's replies in that tree, and the list of the user's initial prompts.
    """
    initial_prompts: list[Message] = []
    grouped_replies: dict[UUID, list[Message]] = {}

    # all messages authored by the user
    user_messages = self.db.query(Message).filter(Message.user_id == user_id)

    for msg in user_messages:
        if msg.message_tree_id == msg.id:
            initial_prompts.append(msg)
            continue
        # first reply seen for a tree creates its list, later ones extend it
        grouped_replies.setdefault(msg.message_tree_id, []).append(msg)

    return grouped_replies, initial_prompts

def _purge_message_internal(self, message_id: UUID) -> None:
    """Delete a single message together with the rows that reference it.

    This internal function does not take care of descendants, of the
    `children_count` of the parent etc.; callers have to handle that.

    Rows removed, in order: journal entries, embeddings, toxicity results and
    text labels of the message, ranking reactions of tasks that share the
    message's parent, the task that inserted the message, tasks that have the
    message as parent, and finally the message itself.

    Args:
        message_id: Id of the message to delete.
    """

    # The journal statement filters on the bound id only. The previous
    # `USING message m` had no join condition, i.e. it was a needless cross join.
    sql_purge_message = """
DELETE FROM journal j WHERE j.message_id = :message_id;
DELETE FROM message_embedding e WHERE e.message_id = :message_id;
DELETE FROM message_toxicity t WHERE t.message_id = :message_id;
DELETE FROM text_labels l WHERE l.message_id = :message_id;
-- delete all ranking results that contain message
DELETE FROM message_reaction r WHERE r.payload_type = 'RankingReactionPayload' AND r.task_id IN (
SELECT t.id FROM message m
JOIN task t ON m.parent_id = t.parent_message_id
WHERE m.id = :message_id);
-- delete task which inserted message
DELETE FROM task t using message m WHERE t.id = m.task_id AND m.id = :message_id;
DELETE FROM task t WHERE t.parent_message_id = :message_id;
DELETE FROM message WHERE id = :message_id;
"""
    r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
    # NOTE(review): with several statements in one execute, `rowcount` presumably
    # reflects only the last statement - confirm for the driver in use.
    logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")

def purge_message_tree(self, message_tree_id: UUID) -> None:
    """Delete a complete message tree and all rows that depend on it.

    Dependent rows (journal, embeddings, toxicity, text labels, reactions,
    tasks, tree state) are deleted before the messages of the tree themselves.

    Args:
        message_tree_id: Id of the tree to delete.
    """
    purge_statements = """
DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
DELETE FROM message_embedding e USING message m WHERE e.message_id = m.Id AND m.message_tree_id = :message_tree_id;
DELETE FROM message_toxicity t USING message m WHERE t.message_id = m.Id AND m.message_tree_id = :message_tree_id;
DELETE FROM text_labels l USING message m WHERE l.message_id = m.Id AND m.message_tree_id = :message_tree_id;
DELETE FROM message_reaction r USING task t WHERE r.task_id = t.id AND t.message_tree_id = :message_tree_id;
DELETE FROM task t WHERE t.message_tree_id = :message_tree_id;
DELETE FROM message_tree_state WHERE message_tree_id = :message_tree_id;
DELETE FROM message WHERE message_tree_id = :message_tree_id;
"""
    bind_params = {"message_tree_id": message_tree_id}
    result = self.db.execute(text(purge_statements), bind_params)
    logger.debug(f"purge_message_tree updated({message_tree_id=}) {result.rowcount} rows.")

@managed_tx_method(CommitMode.FLUSH)
def purge_messages_of_user(self, user_id: UUID, purge_initial_prompts: bool = True):
    """Delete the messages written by `user_id` and patch the affected trees.

    Trees started by the user are removed completely (when
    `purge_initial_prompts` is set). In every other tree the user replied to,
    the user's replies and all of their descendants are deleted, the
    `children_count` of surviving parents is decremented, and the tree is
    reactivated.

    Args:
        user_id: Id of the user whose messages are purged.
        purge_initial_prompts: If true, delete whole trees whose initial prompt
            was written by the user.
    """

    # find all affected message trees
    replies_by_tree, prompts = self.get_user_messages_by_tree(user_id)

    # remove all trees based on initial prompts of the user
    if purge_initial_prompts:
        for p in prompts:
            self.purge_message_tree(p.message_tree_id)
            # the whole tree is gone, so there is nothing left to patch in it
            replies_by_tree.pop(p.message_tree_id, None)

    # patch all affected message trees
    for tree_id, replies in replies_by_tree.items():
        bad_parent_ids = {m.id for m in replies}

        # Load every message of the tree, reviewed or not: with the default
        # `reviewed=True` filter, unreviewed replies of the user would be skipped
        # and the ancestor walk below could hit a parent missing from `by_id`.
        tree_messages = self.pr.fetch_message_tree(tree_id, reviewed=False)
        by_id = {m.id: m for m in tree_messages}

        def is_descendant_of_deleted(m: Message) -> bool:
            """Return True if `m` or one of its ancestors is a reply of the purged user."""
            while True:
                if m.id in bad_parent_ids:
                    return True
                if m.parent_id is None:
                    return False
                m = by_id[m.parent_id]

        # start with deepest messages first
        tree_messages.sort(key=lambda x: x.depth, reverse=True)
        for m in tree_messages:
            if not is_descendant_of_deleted(m):
                continue
            self._purge_message_internal(m.id)

            # try to update child count: only the user's own replies can have a
            # surviving parent, deeper descendants hang below a deleted message
            if m.id in bad_parent_ids:
                assert m.parent_id is not None
                parent = by_id[m.parent_id]
                if parent and not is_descendant_of_deleted(parent):
                    parent.children_count -= 1
                    self.db.add(parent)

        # update children counts
        self.db.flush()

        # reactivate tree
        logger.info(f"reactivating message tree {tree_id}")
        mts = self.pr.fetch_tree_state(tree_id)
        mts.active = True
        self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
        self.check_condition_for_growing_state(tree_id)

@managed_tx_method(CommitMode.FLUSH)
def purge_user(self, user_id: UUID) -> None:
    """Remove the data of a user and mark the account as deleted and disabled.

    The user's messages are purged first via `purge_messages_of_user` (including
    trees the user started). Afterwards the remaining journal, reaction, task,
    message and stats rows of the user are deleted and the user row is updated
    with `deleted = TRUE, enabled = FALSE`.

    Args:
        user_id: Id of the user to purge.
    """
    self.purge_messages_of_user(user_id, purge_initial_prompts=True)

    # delete all remaining rows and ban user
    purge_statements = """
DELETE FROM journal WHERE user_id = :user_id;
DELETE FROM message_reaction WHERE user_id = :user_id;
DELETE FROM task WHERE user_id = :user_id;
DELETE FROM message WHERE user_id = :user_id;
DELETE FROM user_stats WHERE user_id = :user_id;
UPDATE "user" SET deleted = TRUE, enabled = FALSE WHERE id = :user_id;
"""

    bind_params = {"user_id": user_id}
    result = self.db.execute(text(purge_statements), bind_params)
    logger.debug(f"purge_user({user_id=}): {result.rowcount} rows.")


if __name__ == "__main__":
from oasst_backend.api.deps import api_auth
Expand All @@ -994,6 +1137,9 @@ def stats(self) -> TreeManagerStats:
tm = TreeManager(db, pr, cfg)
tm.ensure_tree_states()

# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
# db.commit()

# print("query_num_active_trees", tm.query_num_active_trees())
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
# print("query_replies_need_review", tm.query_replies_need_review())
Expand Down
8 changes: 6 additions & 2 deletions oasst-shared/oasst_shared/exceptions/oasst_api_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class OasstErrorCode(IntEnum):
RATING_OUT_OF_RANGE = 2002
INVALID_RANKING_VALUE = 2003
INVALID_TASK_TYPE = 2004
USER_NOT_SPECIFIED = 2005

NO_MESSAGE_TREE_FOUND = 2006
NO_REPLIES_FOUND = 2007
INVALID_MESSAGE = 2008
Expand All @@ -62,11 +62,15 @@ class OasstErrorCode(IntEnum):
TASK_NOT_COLLECTIVE = 2106
TASK_NOT_ASSIGNED_TO_USER = 2106
TASK_UNEXPECTED_PAYLOAD_TYPE_ = 2107
USER_NOT_FOUND = 2200

# 3000-4000: external resources
HUGGINGFACE_API_ERROR = 3001

# 4000-5000: user
USER_NOT_SPECIFIED = 4000
USER_DISABLED = 4001
USER_NOT_FOUND = 4002


class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""
Expand Down