Skip to content

Commit

Permalink
feat: raise an error when dataset is ready and distribution settings …
Browse files Browse the repository at this point in the history
…are tried to be modified
  • Loading branch information
jfcalvo committed Jul 1, 2024
1 parent b23f3d8 commit c8aa1a9
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
4 changes: 3 additions & 1 deletion argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
)
from argilla_server.models.suggestions import SuggestionCreateWithRecordId
from argilla_server.search_engine import SearchEngine
from argilla_server.validators.datasets import DatasetCreateValidator
from argilla_server.validators.datasets import DatasetCreateValidator, DatasetUpdateValidator
from argilla_server.validators.responses import (
ResponseCreateValidator,
ResponseUpdateValidator,
Expand Down Expand Up @@ -171,6 +171,8 @@ async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset


async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> Dataset:
await DatasetUpdateValidator.validate(db, dataset, dataset_attrs)

return await dataset.update(db, **dataset_attrs)


Expand Down
17 changes: 14 additions & 3 deletions argilla-server/src/argilla_server/validators/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,27 @@

class DatasetCreateValidator:
@classmethod
async def validate(cls, db, dataset: Dataset) -> None:
async def validate(cls, db: AsyncSession, dataset: Dataset) -> None:
await cls._validate_workspace_is_present(db, dataset.workspace_id)
await cls._validate_name_is_not_duplicated(db, dataset.name, dataset.workspace_id)

@classmethod
async def _validate_workspace_is_present(cls, db, workspace_id: UUID) -> None:
async def _validate_workspace_is_present(cls, db: AsyncSession, workspace_id: UUID) -> None:
if await Workspace.get(db, workspace_id) is None:
raise UnprocessableEntityError(f"Workspace with id `{workspace_id}` not found")

@classmethod
async def _validate_name_is_not_duplicated(cls, db, name: str, workspace_id: UUID) -> None:
async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, workspace_id: UUID) -> None:
if await Dataset.get_by(db, name=name, workspace_id=workspace_id):
raise NotUniqueError(f"Dataset with name `{name}` already exists for workspace with id `{workspace_id}`")


class DatasetUpdateValidator:
@classmethod
async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None:
cls._validate_distribution(dataset, dataset_attrs)

@classmethod
def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None:
if dataset.is_ready and dataset_attrs.get("distribution") is not None:
raise UnprocessableEntityError(f"Distribution settings cannot be modified for a published dataset")
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from uuid import UUID

import pytest
from argilla_server.enums import DatasetDistributionStrategy
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from httpx import AsyncClient

from tests.factories import DatasetFactory
Expand Down Expand Up @@ -72,6 +72,53 @@ async def test_update_dataset_without_distribution(self, async_client: AsyncClie
"min_submitted": 1,
}

async def test_update_dataset_without_distribution_for_published_dataset(
self, async_client: AsyncClient, owner_auth_header: dict
):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)

response = await async_client.patch(
self.url(dataset.id),
headers=owner_auth_header,
json={"name": "Dataset updated name"},
)

assert response.status_code == 200
assert response.json()["distribution"] == {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

assert dataset.name == "Dataset updated name"
assert dataset.distribution == {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

async def test_update_dataset_distribution_for_published_dataset(
self, async_client: AsyncClient, owner_auth_header: dict
):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)

response = await async_client.patch(
self.url(dataset.id),
headers=owner_auth_header,
json={
"distribution": {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 4,
},
},
)

assert response.status_code == 422
assert response.json() == {"detail": "Distribution settings cannot be modified for a published dataset"}

assert dataset.distribution == {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

async def test_update_dataset_distribution_with_invalid_strategy(
self, async_client: AsyncClient, owner_auth_header: dict
):
Expand Down

0 comments on commit c8aa1a9

Please sign in to comment.