Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Client): relax client constraints for rules management #2242

Merged
merged 3 commits into from Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 43 additions & 11 deletions src/argilla/client/client.py
Expand Up @@ -47,7 +47,11 @@
)
from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.commons.api import async_bulk
from argilla.client.sdk.commons.errors import InputValueError
from argilla.client.sdk.commons.errors import (
AlreadyExistsApiError,
InputValueError,
NotFoundApiError,
)
from argilla.client.sdk.datasets import api as datasets_api
from argilla.client.sdk.datasets.models import CopyDatasetRequest, TaskType
from argilla.client.sdk.metrics import api as metrics_api
Expand Down Expand Up @@ -583,20 +587,48 @@ def compute_metric(
def add_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]):
"""Adds the dataset labeling rules"""
for rule in rules:
text_classification_api.add_dataset_labeling_rule(
self._client,
name=dataset,
rule=rule,
)

def update_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]):
try:
text_classification_api.add_dataset_labeling_rule(
self._client,
name=dataset,
rule=rule,
)
except AlreadyExistsApiError:
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
_LOGGER.warning(
f"Rule {rule} already exists. Please, update the rule instead."
)
except Exception as ex:
_LOGGER.warning(f"Cannot create rule {rule}: {ex}")

def update_dataset_labeling_rules(
self,
dataset: str,
rules: List[LabelingRule],
):
"""Updates the dataset labeling rules"""
for rule in rules:
text_classification_api.update_dataset_labeling_rule(
self._client, name=dataset, rule=rule
)
try:
text_classification_api.update_dataset_labeling_rule(
self._client,
name=dataset,
rule=rule,
)
except NotFoundApiError:
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
_LOGGER.info(f"Rule {rule} does not exists, creating...")
text_classification_api.add_dataset_labeling_rule(
self._client, name=dataset, rule=rule
)
except Exception as ex:
_LOGGER.warning(f"Cannot update rule {rule}: {ex}")

def delete_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]):
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
for rule in rules:
try:
text_classification_api.delete_dataset_labeling_rule(
self._client, name=dataset, rule=rule
)
except Exception as ex:
_LOGGER.warning(f"Cannot delete rule {rule}: {ex}")
"""Deletes the dataset labeling rules"""
for rule in rules:
text_classification_api.delete_dataset_labeling_rule(
Expand Down
27 changes: 27 additions & 0 deletions tests/labeling/text_classification/test_rule.py
Expand Up @@ -150,6 +150,33 @@ def test_call(monkeypatch, mocked_client, log_dataset):
assert rule(records[1]) is None


def test_add_duplicated_rule(
mocked_client,
log_dataset,
):
rules = [
Rule(query="lab", label="DD"),
Rule(query="lab", label="EF"),
]
add_rules(log_dataset, rules)
new_rules = load_rules(log_dataset)
assert len(new_rules) == 1, new_rules
assert new_rules[0].label == "DD" and new_rules[0].query == "lab"


def test_create_rules_with_update(
mocked_client,
log_dataset,
):
rules = [Rule(query="lab", label="DD"), Rule(query="ob", label="EF")]
update_rules(log_dataset, rules)

new_rules = load_rules(log_dataset)
assert [{"query": r.query, "label": r.label} for r in rules] == [
{"query": r.query, "label": r.label} for r in new_rules
]


def test_load_rules(mocked_client, log_dataset):

mocked_client.post(
Expand Down