def delete_records(
    self,
    name: str,
    query: Optional[str] = None,
    ids: Optional[List[Union[str, int]]] = None,
    discard_only: bool = False,
    discard_when_forbidden: bool = True,
) -> Tuple[int, int]:
    """Delete, or mark as discarded, records of a Rubrix dataset.

    The call is delegated to ``self.datasets.delete_records``; ``discard_only``
    is forwarded as its ``mark_as_discarded`` argument.

    Args:
        name: The dataset name.
        query: An ElasticSearch query written with the query string syntax.
        ids: If provided, the dataset records with these ids are targeted.
        discard_only: If `True`, matched records are not deleted; they are
            marked as `Discarded` instead.
        discard_when_forbidden: Only super-users or the dataset creator can
            delete records. For any other user a "hard" deletion raises a
            `ForbiddenApiError`. When this flag is `True` (the default) the
            client retries by marking the records as ``Discarded`` instead.

    Returns:
        A ``(matched, processed)`` pair: the number of matched records and the
        number of records actually processed. The two can differ when data
        conflicts occur during the operation (matched records changing while
        it runs).

    Examples:
        >>> ## Delete by id
        >>> import rubrix as rb
        >>> rb.delete_records(name="example-dataset", ids=[1,3,5])
        >>> ## Discard records by query
        >>> import rubrix as rb
        >>> rb.delete_records(name="example-dataset", query="metadata.code:33", discard_only=True)
    """
    # Translate the public argument names into the ones the datasets API uses.
    forwarded = {
        "name": name,
        "query": query,
        "ids": ids,
        "mark_as_discarded": discard_only,
        "discard_when_forbidden": discard_when_forbidden,
    }
    return self.datasets.delete_records(**forwarded)
def delete_records(
    self,
    name: str,
    query: Optional[str] = None,
    ids: Optional[List[Union[str, int]]] = None,
    mark_as_discarded: bool = False,
    discard_when_forbidden: bool = True,
) -> Tuple[int, int]:
    """Try to delete the records of a dataset selected by a query or an id list.

    Args:
        name: The dataset name.
        query: The query matching records.
        ids: A list of record ids. A non-empty list takes precedence over
            ``query``. An empty list combined with no query selects nothing:
            the method returns ``(0, 0)`` without contacting the server.
        mark_as_discarded: If `True`, the matched records are marked as
            `Discarded` instead of being deleted.
        discard_when_forbidden: Only super-users or the dataset creator can
            delete records. For any other user a "hard" deletion raises a
            `ForbiddenApiError`. If this parameter is `True`, the call is
            retried once with ``mark_as_discarded=True`` and a `UserWarning`
            is emitted.

    Returns:
        A ``(matched, processed)`` pair taken from the server response. The
        two numbers can differ when data conflicts are found during the
        operation (some matched records change during deletion).

    Raises:
        ForbiddenApiError: When the server refuses the operation and
            ``discard_when_forbidden`` is `False` (also the case for the
            automatic retry).
    """
    if ids is not None and not ids and query is None:
        # An explicit but empty id list selects nothing. Without this guard
        # the falsy list falls through to the query branch and the request
        # is sent with neither ids nor a query, i.e. unfiltered.
        return 0, 0

    selector = {"ids": ids} if ids else {"query_text": query}

    with api_compatibility(self, min_version="0.18"):
        try:
            response = self.__client__.delete(
                path=f"{self._API_PREFIX}/{name}/data?mark_as_discarded={mark_as_discarded}",
                json=selector,
            )
        except ForbiddenApiError as faer:
            if not discard_when_forbidden:
                # Bare raise keeps the original traceback intact.
                raise
            warnings.warn(
                message=f"{faer}. Records will be discarded instead",
                category=UserWarning,
            )
            return self.delete_records(
                name=name,
                query=query,
                ids=ids,
                mark_as_discarded=True,
                discard_when_forbidden=False,  # Next time will raise the error
            )
        else:
            return response["matched"], response["processed"]
def configure_router(router: APIRouter):
    """Register the ``DELETE /{name}/data`` records deletion endpoint on ``router``."""

    # The request body may be any of the task-specific query models.
    AnyTaskQuery = Union[
        TextClassificationQuery, TokenClassificationQuery, Text2TextQuery
    ]

    class DeleteRecordsResponse(BaseModel):
        # ``processed`` counter reported by the storage service
        matched: int
        # number of records deleted or, when discarding, discarded
        processed: int

    @router.delete(
        "/{name}/data",
        operation_id="delete_dataset_records",
        response_model=DeleteRecordsResponse,
    )
    async def delete_dataset_records(
        name: str,
        query: Optional[AnyTaskQuery] = None,
        mark_as_discarded: bool = Query(
            default=False,
            title="If True, matched records won't be deleted."
            " Instead of that, the record status will be changed to `Discarded`",
        ),
        request_deps: CommonTaskHandlerDependencies = Depends(),
        service: DatasetsService = Depends(DatasetsService.get_instance),
        storage: RecordsStorageService = Depends(RecordsStorageService.get_instance),
        current_user: User = Security(auth.get_user, scopes=[]),
    ):
        # Resolve the dataset in the workspace of the current request first.
        dataset = service.find_by_name(
            user=current_user,
            name=name,
            workspace=request_deps.workspace,
        )

        outcome = await storage.delete_records(
            user=current_user,
            dataset=dataset,
            query=query,
            mark_as_discarded=mark_as_discarded,
        )

        # Only one of the two counters is filled, depending on the mode.
        affected = outcome.deleted or outcome.discarded
        return DeleteRecordsResponse(matched=outcome.processed, processed=affected)


router = APIRouter(tags=["datasets"], prefix="/datasets")
configure_router(router)
async def delete_records_by_query(
    self,
    dataset: DatasetDB,
    query: Optional[BackendRecordsQuery] = None,
) -> Tuple[int, int]:
    """Delete the records of ``dataset`` selected by ``query``.

    Delegates to the backend ``delete_records_by_query`` call using the
    dataset id.

    Returns:
        The pair of counters returned by the backend (total, deleted).
    """
    outcome = await self._es.delete_records_by_query(query=query, id=dataset.id)
    matched, removed = outcome
    return matched, removed


async def update_records_by_query(
    self,
    dataset: DatasetDB,
    query: Optional[BackendRecordsQuery] = None,
    **content,
) -> Tuple[int, int]:
    """Apply ``content`` to the records of ``dataset`` selected by ``query``.

    The extra keyword arguments are collected into a single ``content``
    mapping and handed to the backend ``update_records_content`` call.

    Returns:
        The pair of counters returned by the backend (total, updated).
    """
    outcome = await self._es.update_records_content(
        query=query, content=content, id=dataset.id
    )
    matched, changed = outcome
    return matched, changed
@dataclasses.dataclass
class DeleteRecordsOut:
    """Counters describing the outcome of a records deletion request."""

    # first counter returned by the DAO operation
    processed: int = 0
    # second DAO counter when records were marked as discarded
    discarded: int = 0
    # second DAO counter when records were removed
    deleted: int = 0


async def delete_records(
    self,
    user: User,
    dataset: ServiceDataset,
    query: Optional[ServiceBaseRecordsQuery] = None,
    mark_as_discarded: bool = False,
) -> DeleteRecordsOut:
    """Delete, or mark as discarded, the records of ``dataset`` selected by ``query``.

    Discarding is allowed for any caller reaching this method. Hard deletion
    is restricted to super-users and to the user that created the dataset.

    Args:
        user: The user requesting the operation.
        dataset: The dataset holding the records.
        query: Optional records query, passed through to the DAO.
        mark_as_discarded: If `True`, the record status is set to discarded
            instead of removing the records.

    Returns:
        A `DeleteRecordsOut` whose missing counters default to 0.

    Raises:
        ForbiddenOperationError: On hard deletion requested by a user that is
            neither super-user nor the dataset creator.
    """
    if mark_as_discarded:
        matched, discarded = await self.__dao__.update_records_by_query(
            dataset,
            query=query,
            status=TaskStatus.discarded,
        )
        return DeleteRecordsOut(
            processed=matched or 0,
            discarded=discarded or 0,
            deleted=0,
        )

    may_delete = user.is_superuser() or user.username == dataset.created_by
    if not may_delete:
        raise ForbiddenOperationError(
            f"You don't have the necessary permissions to delete records on this dataset. "
            "Only dataset creators or administrators can delete datasets"
        )

    matched, deleted = await self.__dao__.delete_records_by_query(
        dataset, query=query
    )
    return DeleteRecordsOut(
        processed=matched or 0,
        discarded=0,
        deleted=deleted or 0,
    )
def test_delete_records_with_unmatched_records(mocked_client):
    """Deleting by an id that no record has reports zero matched and zero processed."""
    import rubrix as rb

    dataset = "test_delete_records_with_unmatched_records"

    # Start from a clean dataset holding 50 records with ids 0..49.
    rb.delete(dataset)
    records = [
        rb.TextClassificationRecord(
            id=idx, text="This is the text", metadata=dict(idx=idx)
        )
        for idx in range(0, 50)
    ]
    rb.log(name=dataset, records=records)

    matched, processed = rb.delete_records(dataset, ids=["you-wont-find-me-here"])
    assert matched == 0
    assert processed == 0
def create_mock_dataset(client, dataset, records=None):
    """Create a text classification dataset through the bulk endpoint.

    Args:
        client: Test client exposing ``post``.
        dataset: Name of the dataset to create.
        records: Optional records to load with the bulk request. ``None``
            (the default) sends an empty record list.
    """
    # ``None`` sentinel instead of a literal ``[]`` default: a mutable default
    # is created once and shared by every call of the function.
    bulk_records = [] if records is None else records
    client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkRequest(
            tags={"env": "test", "class": "text classification"},
            metadata={"config": {"the": "config"}},
            records=bulk_records,
        ).dict(by_alias=True),
    )


def test_delete_records(mocked_client):
    """Exercise the records deletion endpoint as dataset owner and as another user."""
    dataset_name = "test_delete_records"
    delete_dataset(mocked_client, dataset_name)

    # 99 records, ids 1..99
    create_mock_dataset(
        mocked_client,
        dataset=dataset_name,
        records=[
            {
                "id": i,
                "inputs": {"text": f"This is a text for id {i}"},
            }
            for i in range(1, 100)
        ],
    )

    # Hard deletion by id with the default user is allowed.
    response = mocked_client.delete(
        f"/api/datasets/{dataset_name}/data", json={"ids": [1]}
    )
    assert response.status_code == 200
    assert response.json() == {"matched": 1, "processed": 1}

    try:
        mocked_client.change_current_user("mock-user")

        # Another user cannot hard-delete records...
        response = mocked_client.delete(f"/api/datasets/{dataset_name}/data")
        assert response.status_code == 403
        assert response.json() == {
            "detail": {
                "code": "rubrix.api.errors::ForbiddenOperationError",
                "params": {
                    "detail": "You don't have the necessary permissions to delete records on this dataset."
                    " Only dataset creators or administrators can delete datasets"
                },
            }
        }

        # ...but is allowed to mark them as discarded.
        response = mocked_client.delete(
            f"/api/datasets/{dataset_name}/data?mark_as_discarded=true"
        )
        assert response.status_code == 200
        assert response.json() == {
            "matched": 99,
            "processed": 98,
        }  # different values are caused by conflicts found
    finally:
        # Always restore the default user so later tests are not affected.
        mocked_client.reset_default_user()