diff --git a/src/rubrix/server/services/datasets.py b/src/rubrix/server/services/datasets.py index 2e338f2b28..f035e46a16 100644 --- a/src/rubrix/server/services/datasets.py +++ b/src/rubrix/server/services/datasets.py @@ -99,7 +99,9 @@ def find_by_name( if found_ds is None: raise EntityNotFoundError(name=name, type=Dataset) if found_ds.owner and owner and found_ds.owner != owner: - raise ForbiddenOperationError() + raise EntityNotFoundError( + name=name, type=Dataset + ) if user.is_superuser() else ForbiddenOperationError() return cast(Dataset, found_ds) @@ -115,9 +117,13 @@ def __find_by_name_with_superuser_fallback__( name=name, owner=owner, task=task, as_dataset_class=as_dataset_class ) if not found_ds and user.is_superuser(): - found_ds = self.__dao__.find_by_name( - name=name, owner=None, task=task, as_dataset_class=as_dataset_class - ) + try: + found_ds = self.__dao__.find_by_name( + name=name, owner=None, task=task, as_dataset_class=as_dataset_class + ) + except WrongTaskError: + # A dataset exists in a different workspace and with a different task + pass return found_ds def delete(self, user: User, dataset: Dataset): diff --git a/tests/conftest.py b/tests/conftest.py index f9a0967724..e936cee6c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ @pytest.fixture -def mocked_client(monkeypatch): +def mocked_client(monkeypatch) -> SecuredClient: with TestClient(app, raise_server_exceptions=False) as _client: client = SecuredClient(_client) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 79c31dfae5..3f86286ba3 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -3,7 +3,6 @@ import rubrix as rb from rubrix import TextClassificationSettings, TokenClassificationSettings from rubrix.client import api -from rubrix.client.sdk.commons.errors import AlreadyExistsApiError @pytest.mark.parametrize( diff --git a/tests/functional_tests/test_log_for_text_classification.py 
def test_log_data_in_several_workspaces(mocked_client: SecuredClient):
    """Log a same-named dataset in two workspaces and load each one back.

    One record is logged under the same dataset name in the current workspace
    and in ``test-ws``; loading the dataset from either workspace must then
    return exactly one element.
    """
    workspace = "test-ws"
    dataset = "test_log_data_in_several_workspaces"
    text = "This is a text"

    mocked_client.add_workspaces_to_rubrix_user([workspace])

    curr_ws = rb.get_workspace()
    try:
        # Remove any dataset with this name from both workspaces first.
        for ws in (curr_ws, workspace):
            rb.set_workspace(ws)
            rb.delete(dataset)

        rb.set_workspace(curr_ws)
        rb.log(rb.TextClassificationRecord(id=0, inputs=text), name=dataset)

        rb.set_workspace(workspace)
        rb.log(rb.TextClassificationRecord(id=1, inputs=text), name=dataset)
        ds = rb.load(dataset)
        assert len(ds) == 1

        rb.set_workspace(curr_ws)
        ds = rb.load(dataset)
        assert len(ds) == 1
    finally:
        # The active workspace is client-wide state: restore it even when an
        # assertion or API call above fails, so later tests are not affected.
        rb.set_workspace(curr_ws)
+from typing import Optional + from rubrix.server.apis.v0.models.commons.model import TaskType from rubrix.server.apis.v0.models.datasets import Dataset from rubrix.server.apis.v0.models.text_classification import TextClassificationBulkData +from tests.helpers import SecuredClient def test_delete_dataset(mocked_client): @@ -62,6 +65,48 @@ def test_create_dataset(mocked_client): assert response.status_code == 409 +def test_fetch_dataset_using_workspaces(mocked_client: SecuredClient): + ws = "mock-ws" + dataset_name = "test_fetch_dataset_using_workspaces" + mocked_client.add_workspaces_to_rubrix_user([ws]) + + delete_dataset(mocked_client, dataset_name, workspace=ws) + delete_dataset(mocked_client, dataset_name) + request = dict( + name=dataset_name, + task=TaskType.text_classification, + ) + response = mocked_client.post( + f"/api/datasets?workspace={ws}", + json=request, + ) + + assert response.status_code == 200, response.json() + dataset = Dataset.parse_obj(response.json()) + assert dataset.created_by == "rubrix" + assert dataset.name == dataset_name + assert dataset.owner == ws + assert dataset.task == TaskType.text_classification + + response = mocked_client.post( + f"/api/datasets?workspace={ws}", + json=request, + ) + assert response.status_code == 409, response.json() + + response = mocked_client.post( + f"/api/datasets", + json=request, + ) + + assert response.status_code == 200, response.json() + dataset = Dataset.parse_obj(response.json()) + assert dataset.created_by == "rubrix" + assert dataset.name == dataset_name + assert dataset.owner == "rubrix" + assert dataset.task == TaskType.text_classification + + def test_dataset_naming_validation(mocked_client): request = TextClassificationBulkData(records=[]) dataset = "Wrong dataset name" @@ -166,8 +211,11 @@ def test_open_and_close_dataset(mocked_client): ) -def delete_dataset(client, dataset): - assert client.delete(f"/api/datasets/{dataset}").status_code == 200 +def delete_dataset(client, dataset, 
workspace: Optional[str] = None): + url = f"/api/datasets/{dataset}" + if workspace: + url += f"?workspace={workspace}" + assert client.delete(url).status_code == 200 def create_mock_dataset(client, dataset):