diff --git a/docs/_source/getting_started/installation/server_configuration.md b/docs/_source/getting_started/installation/server_configuration.md
index 170d32d67d..de318d0e9a 100644
--- a/docs/_source/getting_started/installation/server_configuration.md
+++ b/docs/_source/getting_started/installation/server_configuration.md
@@ -23,7 +23,7 @@ You can set following environment variables to further configure your server and
 
 ### Server
 
-- `ELASTICSEARCH`: URL of the connection endpoint of the Elasticsearch instance (Default: `http://localhost:9200`).
+- `ARGILLA_ELASTICSEARCH`: URL of the connection endpoint of the Elasticsearch instance (Default: `http://localhost:9200`).
 
 - `ARGILLA_ELASTICSEARCH_SSL_VERIFY`: If "False", disables SSL certificate verification when connection to the Elasticsearch backend.
 
@@ -35,7 +35,9 @@ You can set following environment variables to further configure your server and
 
 - `ARGILLA_EXACT_ES_SEARCH_ANALYZER`: Default analyzer for `*.exact` fields in textual information (Default: "whitespace").
 
-- `METADATA_FIELDS_LIMIT`: Max number of fields in the metadata (Default: 50, max: 100).
+- `ARGILLA_METADATA_FIELDS_LIMIT`: Max number of fields in the metadata (Default: 50, max: 100).
+
+- `ARGILLA_METADATA_FIELD_LENGTH`: Max length supported for the string metadata fields. Higher values will be truncated. Abusing this may lead to Elastic performance issues (Default: 128).
 
 - `CORS_ORIGINS`: List of host patterns for CORS origin access.
 
diff --git a/src/argilla/_constants.py b/src/argilla/_constants.py
index 066f7f8711..965deecbeb 100644
--- a/src/argilla/_constants.py
+++ b/src/argilla/_constants.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-MAX_KEYWORD_LENGTH = 128
+DEFAULT_MAX_KEYWORD_LENGTH = 128
 
 API_KEY_HEADER_NAME = "X-Argilla-Api-Key"
 
diff --git a/src/argilla/_messages.py b/src/argilla/_messages.py
new file mode 100644
index 0000000000..5972504338
--- /dev/null
+++ b/src/argilla/_messages.py
@@ -0,0 +1,18 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ARGILLA_METADATA_FIELD_WARNING_MESSAGE = (
+    "You can configure this length in the server with the ARGILLA_METADATA_FIELD_LENGTH "
+    "environment variable. Note that, setting this too high may lead to Elastic performance issues."
+)
diff --git a/src/argilla/client/models.py b/src/argilla/client/models.py
index 1d3f12353d..f8f3b0e746 100644
--- a/src/argilla/client/models.py
+++ b/src/argilla/client/models.py
@@ -26,8 +26,8 @@
 from deprecated import deprecated
 from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator
 
-from argilla._constants import MAX_KEYWORD_LENGTH
-from argilla.utils import limit_value_length
+from argilla import _messages
+from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 from argilla.utils.span_utils import SpanUtils
 
 _LOGGER = logging.getLogger(__name__)
@@ -37,16 +37,26 @@ class _Validators(BaseModel):
     """Base class for our record models that takes care of general validations"""
 
     @validator("metadata", check_fields=False)
-    def _check_value_length(cls, v):
-        """Checks metadata values length and apply value truncation for large values"""
-        new_metadata = limit_value_length(v, max_length=MAX_KEYWORD_LENGTH)
-        if new_metadata != v:
-            warnings.warn(
-                "Some metadata values exceed the max length. Those values will be"
-                f" truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters."
+    def _check_value_length(cls, metadata):
+        """Checks metadata values length and warn message for large values"""
+        if not metadata:
+            return metadata
+
+        default_length_exceeded = False
+        for v in metadata.values():
+            if isinstance(v, str) and len(v) > DEFAULT_MAX_KEYWORD_LENGTH:
+                default_length_exceeded = True
+                break
+
+        if default_length_exceeded:
+            message = (
+                "Some metadata values could exceed the max length. For those cases, values will be"
+                f" truncated by keeping only the last {DEFAULT_MAX_KEYWORD_LENGTH} characters. "
+                + _messages.ARGILLA_METADATA_FIELD_WARNING_MESSAGE
             )
+            warnings.warn(message, UserWarning)
 
-        return new_metadata
+        return metadata
 
     @validator("metadata", check_fields=False)
     def _none_to_empty_dict(cls, v):
diff --git a/src/argilla/client/sdk/token_classification/models.py b/src/argilla/client/sdk/token_classification/models.py
index 293a6b8c6e..44b3a5809b 100644
--- a/src/argilla/client/sdk/token_classification/models.py
+++ b/src/argilla/client/sdk/token_classification/models.py
@@ -17,7 +17,7 @@
 
 from pydantic import BaseModel, Field, validator
 
-from argilla._constants import MAX_KEYWORD_LENGTH
+from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 from argilla.client.models import (
     TokenClassificationRecord as ClientTokenClassificationRecord,
 )
@@ -35,7 +35,7 @@ class EntitySpan(BaseModel):
 
     start: int
     end: int
-    label: str = Field(min_length=1, max_length=MAX_KEYWORD_LENGTH)
+    label: str = Field(min_length=1, max_length=DEFAULT_MAX_KEYWORD_LENGTH)
     score: float = Field(default=1.0, ge=0.0, le=1.0)
 
 
diff --git a/src/argilla/server/daos/backend/elasticsearch.py b/src/argilla/server/daos/backend/elasticsearch.py
index 721c2455c7..600b312625 100644
--- a/src/argilla/server/daos/backend/elasticsearch.py
+++ b/src/argilla/server/daos/backend/elasticsearch.py
@@ -930,7 +930,8 @@ def _configure_metadata_fields(self, id: str, metadata_values: Dict[str, Any]):
         def check_metadata_length(metadata_length: int = 0):
             if metadata_length > settings.metadata_fields_limit:
                 raise MetadataLimitExceededError(
-                    length=metadata_length, limit=settings.metadata_fields_limit
+                    length=metadata_length,
+                    limit=settings.metadata_fields_limit,
                 )
 
         def detect_nested_type(v: Any) -> bool:
diff --git a/src/argilla/server/daos/backend/mappings/helpers.py b/src/argilla/server/daos/backend/mappings/helpers.py
index 2e3ce7acdf..74514b6c31 100644
--- a/src/argilla/server/daos/backend/mappings/helpers.py
+++ b/src/argilla/server/daos/backend/mappings/helpers.py
@@ -14,7 +14,6 @@
 
 from typing import Any, Dict, List
 
-from argilla._constants import MAX_KEYWORD_LENGTH
 from argilla.server.settings import settings
 
 EXTENDED_ANALYZER_REF = "extended_analyzer"
@@ -26,12 +25,12 @@
 
 class mappings:
     @staticmethod
-    def keyword_field(enable_text_search: bool = False):
+    def keyword_field(
+        enable_text_search: bool = False,
+    ):
         """Mappings config for keyword field"""
         mapping = {
             "type": "keyword",
-            # TODO: Use environment var and align with fields validators
-            "ignore_above": MAX_KEYWORD_LENGTH,
         }
         if enable_text_search:
             text_field = mappings.text_field()
@@ -41,14 +40,15 @@ def keyword_field(enable_text_search: bool = False):
 
     @staticmethod
     def path_match_keyword_template(
-        path: str, enable_text_search_in_keywords: bool = False
+        path: str,
+        enable_text_search_in_keywords: bool = False,
     ):
         """Dynamic template mappings config for keyword field based on path match"""
         return {
             "path_match": path,
             "match_mapping_type": "string",
             "mapping": mappings.keyword_field(
-                enable_text_search=enable_text_search_in_keywords
+                enable_text_search=enable_text_search_in_keywords,
             ),
         }
 
@@ -167,7 +167,8 @@ def dynamic_metrics_text():
 def dynamic_metadata_text():
     return {
         "metadata.*": mappings.path_match_keyword_template(
-            path="metadata.*", enable_text_search_in_keywords=True
+            path="metadata.*",
+            enable_text_search_in_keywords=True,
         )
     }
 
diff --git a/src/argilla/server/daos/models/records.py b/src/argilla/server/daos/models/records.py
index a9aee58ed9..67de3d2eb2 100644
--- a/src/argilla/server/daos/models/records.py
+++ b/src/argilla/server/daos/models/records.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from datetime import datetime
 from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
 from uuid import uuid4
@@ -19,10 +20,11 @@
 from pydantic import BaseModel, Field, root_validator, validator
 from pydantic.generics import GenericModel
 
-from argilla._constants import MAX_KEYWORD_LENGTH
+from argilla import _messages
 from argilla.server.commons.models import PredictionStatus, TaskStatus, TaskType
 from argilla.server.daos.backend.search.model import BackendRecordsQuery, SortConfig
 from argilla.server.helpers import flatten_dict
+from argilla.server.settings import settings
 from argilla.utils import limit_value_length
 
 
@@ -138,7 +140,17 @@ def flatten_metadata(cls, metadata: Dict[str, Any]):
         """
         if metadata:
             metadata = flatten_dict(metadata, drop_empty=True)
-            metadata = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH)
+            new_metadata = limit_value_length(
+                data=metadata,
+                max_length=settings.metadata_field_length,
+            )
+            message = (
+                "Some metadata values exceed the max length. Those values will be"
+                f" truncated by keeping only the last {settings.metadata_field_length} characters. "
+                + _messages.ARGILLA_METADATA_FIELD_WARNING_MESSAGE
+            )
+            warnings.warn(message, UserWarning)
+            metadata = new_metadata
         return metadata
 
     @classmethod
diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py
index 3fb53fa65d..89884ad016 100644
--- a/src/argilla/server/services/tasks/text_classification/model.py
+++ b/src/argilla/server/services/tasks/text_classification/model.py
@@ -18,7 +18,7 @@
 
 from pydantic import BaseModel, Field, root_validator, validator
 
-from argilla._constants import MAX_KEYWORD_LENGTH
+from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 from argilla.server.commons.models import PredictionStatus, TaskStatus, TaskType
 from argilla.server.helpers import flatten_dict
 from argilla.server.services.datasets import ServiceBaseDataset
@@ -97,9 +97,9 @@ class ClassPrediction(BaseModel):
     @validator("class_label")
     def check_label_length(cls, class_label):
         if isinstance(class_label, str):
-            assert 1 <= len(class_label) <= MAX_KEYWORD_LENGTH, (
-                f"Class name '{class_label}' exceeds max length of {MAX_KEYWORD_LENGTH}"
-                if len(class_label) > MAX_KEYWORD_LENGTH
+            assert 1 <= len(class_label) <= DEFAULT_MAX_KEYWORD_LENGTH, (
+                f"Class name '{class_label}' exceeds max length of {DEFAULT_MAX_KEYWORD_LENGTH}"
+                if len(class_label) > DEFAULT_MAX_KEYWORD_LENGTH
                 else f"Class name must not be empty"
             )
         return class_label
diff --git a/src/argilla/server/services/tasks/token_classification/model.py b/src/argilla/server/services/tasks/token_classification/model.py
index 62e50e499d..3320f991b2 100644
--- a/src/argilla/server/services/tasks/token_classification/model.py
+++ b/src/argilla/server/services/tasks/token_classification/model.py
@@ -12,14 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import typing
-from collections import defaultdict
-from datetime import datetime
 from typing import Any, Dict, List, Optional, Set, Tuple
 
 from pydantic import BaseModel, Field, validator
 
-from argilla._constants import MAX_KEYWORD_LENGTH
+from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 from argilla.server.commons.models import PredictionStatus, TaskType
 from argilla.server.services.datasets import ServiceBaseDataset
 from argilla.server.services.search.model import (
@@ -57,7 +54,7 @@ class EntitySpan(BaseModel):
 
     start: int
     end: int
-    label: str = Field(min_length=1, max_length=MAX_KEYWORD_LENGTH)
+    label: str = Field(min_length=1, max_length=DEFAULT_MAX_KEYWORD_LENGTH)
     score: float = Field(default=1.0, ge=0.0, le=1.0)
 
     @validator("end")
diff --git a/src/argilla/server/settings.py b/src/argilla/server/settings.py
index 9701d6f762..66279bb00b 100644
--- a/src/argilla/server/settings.py
+++ b/src/argilla/server/settings.py
@@ -20,7 +20,9 @@
 from typing import List, Optional
 from urllib.parse import urlparse
 
-from pydantic import BaseSettings, Field, validator
+from pydantic import BaseSettings, Field
+
+from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 
 
 class ApiSettings(BaseSettings):
@@ -82,7 +84,15 @@ class ApiSettings(BaseSettings):
     es_records_index_replicas: int = 0
 
     metadata_fields_limit: int = Field(
-        default=50, gt=0, le=100, description="Max number of fields in metadata"
+        default=50,
+        gt=0,
+        le=100,
+        description="Max number of fields in metadata",
+    )
+    metadata_field_length: int = Field(
+        default=DEFAULT_MAX_KEYWORD_LENGTH,
+        description="Max length supported for the string metadata fields."
+        " Values containing higher than this will be truncated",
     )
 
     enable_telemetry: bool = True
diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py
index 7aaee7e916..137a3e55d2 100644
--- a/tests/client/test_dataset.py
+++ b/tests/client/test_dataset.py
@@ -260,7 +260,10 @@ def test_to_from_datasets(self, records, request):
             "metrics",
         ]
         assert dataset_ds.features["prediction"] == [
-            {"label": datasets.Value("string"), "score": datasets.Value("float64")}
+            {
+                "label": datasets.Value("string"),
+                "score": datasets.Value("float64"),
+            }
         ]
 
         dataset = ar.DatasetForTextClassification.from_datasets(dataset_ds)
diff --git a/tests/client/test_models.py b/tests/client/test_models.py
index 4101cd992c..520453dccb 100644
--- a/tests/client/test_models.py
+++ b/tests/client/test_models.py
@@ -21,7 +21,7 @@
 import pytest
 from pydantic import ValidationError
 
-from argilla._constants import MAX_KEYWORD_LENGTH
+from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 from argilla.client.models import (
     Text2TextRecord,
     TextClassificationRecord,
@@ -202,20 +202,28 @@ def test_token_classification_prediction_validator(prediction, expected):
 def test_text_classification_record_none_inputs():
     """Test validation error for None in inputs"""
     with pytest.raises(ValidationError):
-        TextClassificationRecord(inputs={"text": None})
+        TextClassificationRecord.parse_obj(dict(inputs={"text": None}))
 
 
 def test_metadata_values_length():
     text = "oh yeah!"
-    metadata = {"too_long": "a" * 200}
+    expected_length = 200
+    metadata = {"too_long": "a" * expected_length}
 
-    record = TextClassificationRecord(inputs={"text": text}, metadata=metadata)
-    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
+    with pytest.warns(expected_warning=UserWarning):
+        record = TextClassificationRecord(
+            inputs={"text": text},
+            metadata=metadata,
+        )
+        assert len(record.metadata["too_long"]) == expected_length
 
-    record = TokenClassificationRecord(
-        text=text, tokens=text.split(), metadata=metadata
-    )
-    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
+    with pytest.warns(expected_warning=UserWarning):
+        record = TokenClassificationRecord(
+            text=text,
+            tokens=text.split(),
+            metadata=metadata,
+        )
+        assert len(record.metadata["too_long"]) == expected_length
 
 
 def test_model_serialization_with_numpy_nan():
diff --git a/tests/functional_tests/datasets/test_delete_records_from_datasets.py b/tests/functional_tests/datasets/test_delete_records_from_datasets.py
index 5e04c201ef..20235429cb 100644
--- a/tests/functional_tests/datasets/test_delete_records_from_datasets.py
+++ b/tests/functional_tests/datasets/test_delete_records_from_datasets.py
@@ -100,7 +100,9 @@ def test_delete_records_with_unmatched_records(mocked_client):
         name=dataset,
         records=[
             ar.TextClassificationRecord(
-                id=i, text="This is the text", metadata=dict(idx=i)
+                id=i,
+                text="This is the text",
+                metadata=dict(idx=i),
             )
             for i in range(0, 50)
         ],
diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py
index b3fcb70068..ebbc61af47 100644
--- a/tests/server/text_classification/test_model.py
+++ b/tests/server/text_classification/test_model.py
@@ -15,7 +15,7 @@
 import pytest
 from pydantic import ValidationError
 
-from argilla._constants import MAX_KEYWORD_LENGTH
+from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 from argilla.server.apis.v0.models.text_classification import (
     TextClassificationAnnotation,
     TextClassificationQuery,
@@ -159,7 +159,7 @@ def test_too_long_metadata():
         }
     )
 
-    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
+    assert len(record.metadata["too_long"]) == DEFAULT_MAX_KEYWORD_LENGTH
 
 
 def test_too_long_label():
diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py
index ebee8f6f88..fd3f6005dd 100644
--- a/tests/server/token_classification/test_model.py
+++ b/tests/server/token_classification/test_model.py
@@ -16,7 +16,7 @@
 import pytest
 from pydantic import ValidationError
 
-from argilla._constants import MAX_KEYWORD_LENGTH
+from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 from argilla.server.apis.v0.models.token_classification import (
     TokenClassificationAnnotation,
     TokenClassificationQuery,
@@ -160,7 +160,7 @@ def test_too_long_metadata():
         }
    )
 
-    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
+    assert len(record.metadata["too_long"]) == DEFAULT_MAX_KEYWORD_LENGTH
 
 
 def test_entity_label_too_long():
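
Reviewer note, not part of the patch: a minimal sketch of the client-side behaviour the diff above introduces. With the new `_check_value_length` validator in `src/argilla/client/models.py`, string metadata values longer than `DEFAULT_MAX_KEYWORD_LENGTH` (128) are no longer truncated on the client; constructing the record only emits a `UserWarning`, and truncation happens server-side according to `ARGILLA_METADATA_FIELD_LENGTH`. The snippet only exercises record construction, so it assumes no running server; the metadata key and value are made up for illustration.

```python
import warnings

import argilla as ar

# String metadata value longer than the 128-character default
# (DEFAULT_MAX_KEYWORD_LENGTH in src/argilla/_constants.py).
metadata = {"too_long": "a" * 200}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    record = ar.TextClassificationRecord(
        text="This is the text",
        metadata=metadata,
    )

# Client side: the value is kept intact, only a UserWarning is emitted.
assert len(record.metadata["too_long"]) == 200
assert any(issubclass(w.category, UserWarning) for w in caught)

# Server side: when the record is logged, string metadata values are truncated
# to ARGILLA_METADATA_FIELD_LENGTH characters (default 128), configurable via
# the environment variable documented in server_configuration.md above.
```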