diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py index b1465b71c8..36b8d28698 100644 --- a/src/rubrix/client/models.py +++ b/src/rubrix/client/models.py @@ -21,12 +21,63 @@ import warnings from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator, validator from rubrix._constants import MAX_KEYWORD_LENGTH from rubrix.server.commons.helpers import limit_value_length +class _RootValidators(BaseModel): + """Base class for our record models that takes care of root validations""" + + @root_validator + def _check_value_length(cls, values): + """Checks metadata values length and apply value truncation for large values""" + new_metadata = limit_value_length( + values["metadata"], max_length=MAX_KEYWORD_LENGTH + ) + if new_metadata != values["metadata"]: + warnings.warn( + "Some metadata values exceed the max length. " + f"Those values will be truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters." + ) + values["metadata"] = new_metadata + + return values + + @root_validator + def _check_agents(cls, values): + """Triggers a warning when ONLY agents are provided""" + if ( + values.get("annotation_agent") is not None + and values.get("annotation") is None + ): + warnings.warn( + "You provided an `annotation_agent`, but no `annotation`. The `annotation_agent` will not be logged to the server." + ) + if ( + values.get("prediction_agent") is not None + and values.get("prediction") is None + ): + warnings.warn( + "You provided an `prediction_agent`, but no `prediction`. The `prediction_agent` will not be logged to the server." + ) + + return values + + @root_validator + def _check_and_update_status(cls, values): + """Updates the status if an annotation is provided and no status is specified.""" + values["status"] = values.get("status") or ( + "Default" if values.get("annotation") is None else "Validated" + ) + + return values + + class Config: + extra = "forbid" + + class BulkResponse(BaseModel): """Summary response when logging records to the Rubrix server. @@ -55,7 +106,7 @@ class TokenAttributions(BaseModel): attributions: Dict[str, float] = Field(default_factory=dict) -class TextClassificationRecord(BaseModel): +class TextClassificationRecord(_RootValidators): """Record for text classification Args: @@ -64,10 +115,10 @@ class TextClassificationRecord(BaseModel): prediction: A list of tuples containing the predictions for the record. The first entry of the tuple is the predicted label, the second entry is its corresponding score. - annotation: - A string or a list of strings (multilabel) corresponding to the annotation (gold label) for the record. prediction_agent: Name of the prediction agent. By default, this is set to the hostname of your machine. + annotation: + A string or a list of strings (multilabel) corresponding to the annotation (gold label) for the record. annotation_agent: Name of the prediction agent. By default, this is set to the hostname of your machine. multi_label: @@ -99,17 +150,18 @@ class TextClassificationRecord(BaseModel): inputs: Union[str, List[str], Dict[str, Union[str, List[str]]]] prediction: Optional[List[Tuple[str, float]]] = None - annotation: Optional[Union[str, List[str]]] = None prediction_agent: Optional[str] = None + annotation: Optional[Union[str, List[str]]] = None annotation_agent: Optional[str] = None - multi_label: bool = False + multi_label: bool = False explanation: Optional[Dict[str, List[TokenAttributions]]] = None id: Optional[Union[int, str]] = None metadata: Dict[str, Any] = Field(default_factory=dict) status: Optional[str] = None event_timestamp: Optional[datetime.datetime] = None + metrics: Optional[Dict[str, Any]] = None @validator("inputs", pre=True) @@ -119,20 +171,8 @@ def input_as_dict(cls, inputs): return inputs return dict(text=inputs) - @validator("metadata", pre=True) - def check_value_length(cls, metadata): - return _limit_metadata_values(metadata) - def __init__(self, *args, **kwargs): - """Custom init to handle dynamic defaults""" - # noinspection PyArgumentList - super().__init__(*args, **kwargs) - self.status = self.status or ( - "Default" if self.annotation is None else "Validated" - ) - - -class TokenClassificationRecord(BaseModel): +class TokenClassificationRecord(_RootValidators): """Record for a token classification task Args: @@ -145,11 +185,11 @@ class TokenClassificationRecord(BaseModel): A list of tuples containing the predictions for the record. The first entry of the tuple is the name of predicted entity, the second and third entry correspond to the start and stop character index of the entity. EXPERIMENTAL: The fourth entry is optional and corresponds to the score of the entity. + prediction_agent: + Name of the prediction agent. By default, this is set to the hostname of your machine. annotation: A list of tuples containing annotations (gold labels) for the record. The first entry of the tuple is the name of the entity, the second and third entry correspond to the start and stop char index of the entity. - prediction_agent: - Name of the prediction agent. By default, this is set to the hostname of your machine. annotation_agent: Name of the prediction agent. By default, this is set to the hostname of your machine. id: @@ -180,29 +220,19 @@ class TokenClassificationRecord(BaseModel): prediction: Optional[ List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]] ] = None - annotation: Optional[List[Tuple[str, int, int]]] = None prediction_agent: Optional[str] = None + annotation: Optional[List[Tuple[str, int, int]]] = None annotation_agent: Optional[str] = None id: Optional[Union[int, str]] = None metadata: Dict[str, Any] = Field(default_factory=dict) status: Optional[str] = None event_timestamp: Optional[datetime.datetime] = None - metrics: Optional[Dict[str, Any]] = None - - @validator("metadata", pre=True) - def check_value_length(cls, metadata): - return _limit_metadata_values(metadata) - def __init__(self, *args, **kwargs): - """Custom init to handle dynamic defaults""" - super().__init__(*args, **kwargs) - self.status = self.status or ( - "Default" if self.annotation is None else "Validated" - ) + metrics: Optional[Dict[str, Any]] = None -class Text2TextRecord(BaseModel): +class Text2TextRecord(_RootValidators): """Record for a text to text task Args: @@ -211,10 +241,10 @@ class Text2TextRecord(BaseModel): prediction: A list of strings or tuples containing predictions for the input text. If tuples, the first entry is the predicted text, the second entry is its corresponding score. - annotation: - A string representing the expected output text for the given input text. prediction_agent: Name of the prediction agent. By default, this is set to the hostname of your machine. + annotation: + A string representing the expected output text for the given input text. annotation_agent: Name of the prediction agent. By default, this is set to the hostname of your machine. id: @@ -241,14 +271,15 @@ class Text2TextRecord(BaseModel): text: str prediction: Optional[List[Union[str, Tuple[str, float]]]] = None - annotation: Optional[str] = None prediction_agent: Optional[str] = None + annotation: Optional[str] = None annotation_agent: Optional[str] = None id: Optional[Union[int, str]] = None metadata: Dict[str, Any] = Field(default_factory=dict) status: Optional[str] = None event_timestamp: Optional[datetime.datetime] = None + metrics: Optional[Dict[str, Any]] = None @validator("prediction") @@ -262,27 +293,5 @@ def prediction_as_tuples( return prediction return [(text, 1.0) for text in prediction] - @validator("metadata", pre=True) - def check_value_length(cls, metadata): - return _limit_metadata_values(metadata) - - def __init__(self, *args, **kwargs): - """Custom init to handle dynamic defaults""" - super().__init__(*args, **kwargs) - self.status = self.status or ( - "Default" if self.annotation is None else "Validated" - ) - - -def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]: - """Checks metadata values length and apply value truncation for large values""" - new_value = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH) - if new_value != metadata: - warnings.warn( - "Some metadata values exceed the max length. " - f"Those values will be truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters." - ) - return new_value - Record = Union[TextClassificationRecord, TokenClassificationRecord, Text2TextRecord] diff --git a/tests/client/test_models.py b/tests/client/test_models.py index a14c06ff0b..1e3133090d 100644 --- a/tests/client/test_models.py +++ b/tests/client/test_models.py @@ -13,14 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from typing import Any, Optional import numpy import pytest from pydantic import ValidationError -from rubrix._constants import MAX_KEYWORD_LENGTH -from rubrix.client.models import Text2TextRecord, TextClassificationRecord -from rubrix.client.models import TokenClassificationRecord +from rubrix._constants import MAX_KEYWORD_LENGTH +from rubrix.client.models import ( + Text2TextRecord, + TextClassificationRecord, + TokenClassificationRecord, + _RootValidators, +) @pytest.mark.parametrize( @@ -92,3 +97,35 @@ def test_model_serialization_with_numpy_nan(): ) json_record = json.loads(record.json()) + + +def test_warning_when_only_agent(): + class MockRecord(_RootValidators): + prediction: Optional[Any] = None + prediction_agent: Optional[str] = None + annotation: Optional[Any] = None + annotation_agent: Optional[str] = None + metadata: Optional[Any] = None + status: Optional[str] = None + + with pytest.warns( + UserWarning, match="`prediction_agent` will not be logged to the server." + ): + MockRecord(prediction_agent="mock") + with pytest.warns( + UserWarning, match="`annotation_agent` will not be logged to the server." + ): + MockRecord(annotation_agent="mock") + + +def test_forbid_extra(): + class MockRecord(_RootValidators): + prediction: Optional[Any] = None + prediction_agent: Optional[str] = None + annotation: Optional[Any] = None + annotation_agent: Optional[str] = None + metadata: Optional[Any] = None + status: Optional[str] = None + + with pytest.raises(ValidationError): + MockRecord(extra="mock")