diff --git a/src/rubrix/__init__.py b/src/rubrix/__init__.py index 2651bca4e4..bfc3f8af9e 100644 --- a/src/rubrix/__init__.py +++ b/src/rubrix/__init__.py @@ -24,7 +24,7 @@ from rubrix.logging import configure_logging as _configure_logging from . import _version -from .utils import _LazyRubrixModule +from .utils import LazyRubrixModule as _LazyRubrixModule __version__ = _version.version diff --git a/src/rubrix/client/datasets.py b/src/rubrix/client/datasets.py index 133cfa4dc2..24ca7dac8c 100644 --- a/src/rubrix/client/datasets.py +++ b/src/rubrix/client/datasets.py @@ -28,6 +28,7 @@ TokenClassificationRecord, ) from rubrix.client.sdk.datasets.models import TaskType +from rubrix.utils.span_utils import SpanUtils _LOGGER = logging.getLogger(__name__) @@ -877,12 +878,11 @@ def _prepare_for_training_with_transformers(self): class_tags = datasets.ClassLabel(names=class_tags) def spans2iob(example): - r = TokenClassificationRecord( - text=example["text"], - tokens=example["tokens"], - annotation=self.__entities_to_tuple__(example["annotation"]), - ) - return class_tags.str2int(r.spans2iob(r.annotation)) + span_utils = SpanUtils(example["text"], example["tokens"]) + entity_spans = self.__entities_to_tuple__(example["annotation"]) + tags = span_utils.to_tags(entity_spans) + + return class_tags.str2int(tags) ds = ( self.to_datasets() diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py index 49e11af7d7..baae6b47b7 100644 --- a/src/rubrix/client/models.py +++ b/src/rubrix/client/models.py @@ -20,7 +20,6 @@ import datetime import logging import warnings -from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd @@ -28,6 +27,7 @@ from rubrix._constants import MAX_KEYWORD_LENGTH from rubrix.utils import limit_value_length +from rubrix.utils.span_utils import SpanUtils _LOGGER = logging.getLogger(__name__) @@ -295,8 +295,7 @@ class TokenClassificationRecord(_Validators): metrics: Optional[Dict[str, Any]] = None search_keywords: Optional[List[str]] = None - __chars2tokens__: Dict[int, int] = PrivateAttr(default=None) - __tokens2chars__: Dict[int, Tuple[int, int]] = PrivateAttr(default=None) + _span_utils: SpanUtils = PrivateAttr() def __init__( self, @@ -320,45 +319,18 @@ def __init__( text = " ".join(tokens) super().__init__(text=text, tokens=tokens, **data) + + self._span_utils = SpanUtils(self.text, self.tokens) + + if self.annotation: + self.annotation = self._validate_spans(self.annotation) + if self.prediction: + self.prediction = self._validate_spans(self.prediction) + if self.annotation and tags: _LOGGER.warning("Annotation already provided, `tags` won't be used") - return - if tags: - self.annotation = self.__tags2entities__(tags) - - def __tags2entities__(self, tags: List[str]) -> List[Tuple[str, int, int]]: - idx = 0 - entities = [] - entity_starts = False - while idx < len(tags): - tag = tags[idx] - if tag == "O": - entity_starts = False - if tag != "O": - prefix, entity = tag.split("-") - if prefix in ["B", "U"]: - if prefix == "B": - entity_starts = True - char_start, char_end = self.token_span(token_idx=idx) - entities.append( - {"entity": entity, "start": char_start, "end": char_end + 1} - ) - elif prefix in ["I", "L"]: - if not entity_starts: - _LOGGER.warning( - "Detected non-starting tag and first entity token was not found." 
- f"Assuming {tag} as first entity token" - ) - entity_starts = True - char_start, char_end = self.token_span(token_idx=idx) - entities.append( - {"entity": entity, "start": char_start, "end": char_end + 1} - ) - - _, char_end = self.token_span(token_idx=idx) - entities[-1]["end"] = char_end + 1 - idx += 1 - return [(value["entity"], value["start"], value["end"]) for value in entities] + elif tags: + self.annotation = self._span_utils.from_tags(tags) def __setattr__(self, name: str, value: Any): """Make text and tokens immutable""" @@ -366,6 +338,30 @@ def __setattr__(self, name: str, value: Any): raise AttributeError(f"You cannot assign a new value to `{name}`") super().__setattr__(name, value) + def _validate_spans( + self, spans: List[Tuple[str, int, int]] + ) -> List[Tuple[str, int, int]]: + """Validates the entity spans with respect to the tokens. + + If necessary, also performs an automatic correction of the spans. + + Args: + spans: The entity spans to validate. + + Returns: + The optionally corrected spans. + + Raises: + ValidationError: If spans are not valid or misaligned. + """ + try: + self._span_utils.validate(spans) + except ValueError: + spans = self._span_utils.correct(spans) + self._span_utils.validate(spans) + + return spans + @validator("tokens", pre=True) def _normalize_tokens(cls, value): if isinstance(value, list): @@ -375,7 +371,7 @@ def _normalize_tokens(cls, value): return value @validator("prediction") - def add_default_score( + def _add_default_score( cls, prediction: Optional[ List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]] @@ -391,103 +387,64 @@ def add_default_score( for pred in prediction ] - @staticmethod - def __build_indices_map__( - text: str, tokens: Tuple[str, ...] - ) -> Tuple[Dict[int, int], Dict[int, Tuple[int, int]]]: - """ - Build the indices mapping between text characters and tokens where belongs to, - and vice versa. - - chars2tokens index contains is the token idx where i char is contained (if any). - - Out-of-token characters won't be included in this map, - so access should be using ``chars2tokens_map.get(i)`` - instead of ``chars2tokens_map[i]``. - - """ - - def chars2tokens_index(text_, tokens_): - chars_map = {} - current_token = 0 - current_token_char_start = 0 - for idx, char in enumerate(text_): - relative_idx = idx - current_token_char_start - if ( - relative_idx < len(tokens_[current_token]) - and char == tokens_[current_token][relative_idx] - ): - chars_map[idx] = current_token - elif ( - current_token + 1 < len(tokens_) - and relative_idx >= len(tokens_[current_token]) - and char == tokens_[current_token + 1][0] - ): - current_token += 1 - current_token_char_start += relative_idx - chars_map[idx] = current_token - return chars_map - - def tokens2chars_index( - chars2tokens: Dict[int, int] - ) -> Dict[int, Tuple[int, int]]: - tokens2chars_map = defaultdict(list) - for c, t in chars2tokens.items(): - tokens2chars_map[t].append(c) - - return { - token_idx: (min(chars), max(chars)) - for token_idx, chars in tokens2chars_map.items() - } - - chars2tokens_idx = chars2tokens_index(text_=text, tokens_=tokens) - return chars2tokens_idx, tokens2chars_index(chars2tokens_idx) + @validator("text") + def _check_if_empty_after_strip(cls, text: str): + assert text.strip(), "The provided `text` contains only whitespaces." 
+ return text + + @property + def __chars2tokens__(self) -> Dict[int, int]: + """DEPRECATED, please use the ``rubrix.utils.span_utils.SpanUtils.char_to_token_idx`` attribute.""" + warnings.warn( + "The `__chars2tokens__` attribute is deprecated and will be removed in a future version. " + "Please use the `rubrix.utils.span_utils.SpanUtils.char_to_token_idx` attribute instead.", + FutureWarning, + ) + return self._span_utils.char_to_token_idx + + @property + def __tokens2chars__(self) -> Dict[int, Tuple[int, int]]: + """DEPRECATED, please use the ``rubrix.utils.span_utils.SpanUtils.token_to_char_idx`` attribute.""" + warnings.warn( + "The `__tokens2chars__` attribute is deprecated and will be removed in a future version. " + "Please use the `rubrix.utils.span_utils.SpanUtils.token_to_char_idx` attribute instead.", + FutureWarning, + ) + return self._span_utils.token_to_char_idx def char_id2token_id(self, char_idx: int) -> Optional[int]: - """ - Given a character id, returns the token id it belongs to. - ``None`` otherwise - """ - - if self.__chars2tokens__ is None: - self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__( - self.text, tuple(self.tokens) - ) - return self.__chars2tokens__.get(char_idx) + """DEPRECATED, please use the ``rubrix.utils.span_utils.SpanUtils.char_to_token_idx`` dict instead.""" + warnings.warn( + "The `char_id2token_id` method is deprecated and will be removed in a future version. " + "Please use the `rubrix.utils.span_utils.SpanUtils.char_to_token_idx` dict instead.", + FutureWarning, + ) + return self._span_utils.char_to_token_idx.get(char_idx) def token_span(self, token_idx: int) -> Tuple[int, int]: - """ - Given a token id, returns the start and end characters. - Raises an ``IndexError`` if token id is out of tokens list indices - """ - if self.__tokens2chars__ is None: - self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__( - self.text, tuple(self.tokens) - ) - if token_idx not in self.__tokens2chars__: + """DEPRECATED, please use the ``rubrix.utils.span_utils.SpanUtils.token_to_char_idx`` dict instead.""" + warnings.warn( + "The `token_span` method is deprecated and will be removed in a future version. " + "Please use the `rubrix.utils.span_utils.SpanUtils.token_to_char_idx` dict instead.", + FutureWarning, + ) + if token_idx not in self._span_utils.token_to_char_idx: raise IndexError(f"Token id {token_idx} out of bounds") - return self.__tokens2chars__[token_idx] + return self._span_utils.token_to_char_idx[token_idx] def spans2iob( self, spans: Optional[List[Tuple[str, int, int]]] = None ) -> Optional[List[str]]: - """Build the iob tags sequence for a list of spans annoations""" + """DEPRECATED, please use the ``rubrix.utils.SpanUtils.to_tags()`` method.""" + warnings.warn( + "'spans2iob' is deprecated and will be removed in a future version. 
" + "Please use the `rubrix.utils.SpanUtils.to_tags()` method instead, and adapt your code accordingly.", + FutureWarning, + ) if spans is None: return None - - tags = ["O"] * len(self.tokens) - for label, start, end in spans: - token_start = self.char_id2token_id(start) - token_end = self.char_id2token_id(end - 1) - assert ( - token_start is not None and token_end is not None - ), "Provided spans are missaligned at token level" - tags[token_start] = f"B-{label}" - for idx in range(token_start + 1, token_end + 1): - tags[idx] = f"I-{label}" - - return tags + return self._span_utils.to_tags(spans) class Text2TextRecord(_Validators): diff --git a/src/rubrix/server/apis/v0/models/token_classification.py b/src/rubrix/server/apis/v0/models/token_classification.py index 948fc559ab..b7e16146e2 100644 --- a/src/rubrix/server/apis/v0/models/token_classification.py +++ b/src/rubrix/server/apis/v0/models/token_classification.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, root_validator, validator @@ -35,6 +36,7 @@ from rubrix.server.services.tasks.token_classification.model import ( ServiceTokenClassificationDataset, ) +from rubrix.utils import SpanUtils class TokenClassificationAnnotation(_TokenClassificationAnnotation): diff --git a/src/rubrix/server/services/tasks/token_classification/metrics.py b/src/rubrix/server/services/tasks/token_classification/metrics.py index f574cde4c5..15f048faf3 100644 --- a/src/rubrix/server/services/tasks/token_classification/metrics.py +++ b/src/rubrix/server/services/tasks/token_classification/metrics.py @@ -6,8 +6,10 @@ from rubrix.server.services.metrics.models import CommonTasksMetrics from rubrix.server.services.tasks.token_classification.model import ( EntitySpan, + ServiceTokenClassificationAnnotation, ServiceTokenClassificationRecord, ) +from rubrix.utils import SpanUtils class F1Metric(ServicePythonMetric[ServiceTokenClassificationRecord]): @@ -210,7 +212,7 @@ def mention_tokens_length(entity: EntitySpan) -> int: [ token_idx for i in range(entity.start, entity.end) - for token_idx in [record.char_id2token_id(i)] + for token_idx in [record.span_utils.char_to_token_idx.get(i)] if token_idx is not None ] ) @@ -244,13 +246,15 @@ def build_tokens_metrics( idx=token_idx, value=token_value, char_start=char_start, - char_end=char_end, + # TODO(@frascuchon): Align char span definition to the entity level definition + # (char_end should be the next char after the token span boundaries). 
+ char_end=char_end - 1, capitalness=cls.capitalness(token_value), - length=1 + (char_end - char_start), + length=char_end - char_start, tag=tags[token_idx] if tags else None, ) for token_idx, token_value in enumerate(record.tokens) - for char_start, char_end in [record.token_span(token_idx)] + for char_start, char_end in [record.span_utils.token_to_char_idx[token_idx]] ] @classmethod @@ -258,8 +262,9 @@ def record_metrics(cls, record: ServiceTokenClassificationRecord) -> Dict[str, A """Compute metrics at record level""" base_metrics = super(TokenClassificationMetrics, cls).record_metrics(record) - annotated_tags = record.annotated_iob_tags() or [] - predicted_tags = record.predicted_iob_tags() or [] + span_utils = SpanUtils(record.text, record.tokens) + annotated_tags = cls._compute_iob_tags(span_utils, record.annotation) or [] + predicted_tags = cls._compute_iob_tags(span_utils, record.prediction) or [] tokens_metrics = cls.build_tokens_metrics( record, predicted_tags or annotated_tags @@ -284,6 +289,26 @@ def record_metrics(cls, record: ServiceTokenClassificationRecord) -> Dict[str, A }, } + @staticmethod + def _compute_iob_tags( + span_utils: SpanUtils, + annotation: Optional[ServiceTokenClassificationAnnotation], + ) -> Optional[List[str]]: + """Helper method to compute IOB tags from entity spans + + Args: + span_utils: Helper class to perform the computation. + annotation: Contains the spans from which to compute the IOB tags. + + Returns: + The IOB tags or None if ``annotation`` is None. + """ + if annotation is None: + return None + + spans = [(ent.label, ent.start, ent.end) for ent in annotation.entities] + return span_utils.to_tags(spans) + metrics: ClassVar[List[ServiceBaseMetric]] = ( CommonTasksMetrics.metrics + [ diff --git a/src/rubrix/server/services/tasks/token_classification/model.py b/src/rubrix/server/services/tasks/token_classification/model.py index a66c014fb2..604a979670 100644 --- a/src/rubrix/server/services/tasks/token_classification/model.py +++ b/src/rubrix/server/services/tasks/token_classification/model.py @@ -30,6 +30,7 @@ ServiceBaseAnnotation, ServiceBaseRecord, ) +from rubrix.utils import SpanUtils PREDICTED_MENTIONS_ES_FIELD_NAME = "predicted_mentions" MENTIONS_ES_FIELD_NAME = "mentions" @@ -83,9 +84,7 @@ class ServiceTokenClassificationRecord( tokens: List[str] = Field(min_items=1) text: str = Field() _raw_text: Optional[str] = Field(alias="raw_text") - - __chars2tokens__: Dict[int, int] = None - __tokens2chars__: Dict[int, Tuple[int, int]] = None + _span_utils: SpanUtils # TODO: review this. 
_predicted: Optional[PredictionStatus] = Field(alias="predicted") @@ -109,118 +108,43 @@ def extended_fields(self) -> Dict[str, Any]: def __init__(self, **data): super().__init__(**data) - self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__() - - self.check_annotation(self.prediction) - self.check_annotation(self.annotation) - - def char_id2token_id(self, char_idx: int) -> Optional[int]: - return self.__chars2tokens__.get(char_idx) + self._span_utils = SpanUtils(self.text, self.tokens) - def token_span(self, token_idx: int) -> Tuple[int, int]: - if token_idx not in self.__tokens2chars__: - raise IndexError(f"Token id {token_idx} out of bounds") - return self.__tokens2chars__[token_idx] + if self.annotation: + self._validate_spans(self.annotation) + if self.prediction: + self._validate_spans(self.prediction) - def __build_indices_map__( - self, - ) -> Tuple[Dict[int, int], Dict[int, Tuple[int, int]]]: - """ - Build the indices mapping between text characters and tokens where belongs to, - and vice versa. + def _validate_spans(self, annotation: ServiceTokenClassificationAnnotation): + """Validates the spans with respect to the tokens. - chars2tokens index contains is the token idx where i char is contained (if any). + If necessary, also performs an automatic correction of the spans. - Out-of-token characters won't be included in this map, - so access should be using ``chars2tokens_map.get(i)`` - instead of ``chars2tokens_map[i]``. + Args: + span_utils: Helper class to perform the checks. + annotation: Contains the spans to validate. + Raises: + ValidationError: If spans are not valid or misaligned. """ - - def chars2tokens_index(): - def is_space_after_token(char, idx: int, chars_map) -> str: - return char == " " and idx - 1 in chars_map - - chars_map = {} - current_token = 0 - current_token_char_start = 0 - for idx, char in enumerate(self.text): - if is_space_after_token(char, idx, chars_map): - continue - relative_idx = idx - current_token_char_start - if ( - relative_idx < len(self.tokens[current_token]) - and char == self.tokens[current_token][relative_idx] - ): - chars_map[idx] = current_token - elif ( - current_token + 1 < len(self.tokens) - and relative_idx >= len(self.tokens[current_token]) - and char == self.tokens[current_token + 1][0] - ): - current_token += 1 - current_token_char_start += relative_idx - chars_map[idx] = current_token - - return chars_map - - def tokens2chars_index( - chars2tokens: Dict[int, int] - ) -> Dict[int, Tuple[int, int]]: - tokens2chars_map = defaultdict(list) - for c, t in chars2tokens.items(): - tokens2chars_map[t].append(c) - - return { - token_idx: (min(chars), max(chars)) - for token_idx, chars in tokens2chars_map.items() - } - - chars2tokens_idx = chars2tokens_index() - return chars2tokens_idx, tokens2chars_index(chars2tokens_idx) - - def check_annotation( - self, - annotation: Optional[ServiceTokenClassificationAnnotation], - ): - """Validates entities in terms of offset spans""" - - def adjust_span_bounds(start, end): - if start < 0: - start = 0 - if entity.end > len(self.text): - end = len(self.text) - while start <= len(self.text) and not self.text[start].strip(): - start += 1 - while not self.text[end - 1].strip(): - end -= 1 - return start, end - - if annotation: - for entity in annotation.entities: - entity.start, entity.end = adjust_span_bounds(entity.start, entity.end) - mention = self.text[entity.start : entity.end] - assert len(mention) > 0, f"Empty offset defined for entity {entity}" - - token_start = 
self.char_id2token_id(entity.start) - token_end = self.char_id2token_id(entity.end - 1) - - assert not ( - token_start is None or token_end is None - ), f"Provided entity span {self.text[entity.start: entity.end]} is not aligned with provided tokens." - "Some entity chars could be reference characters out of tokens" - - span_start, _ = self.token_span(token_start) - _, span_end = self.token_span(token_end) - - assert ( - self.text[span_start : span_end + 1] == mention - ), f"Defined offset [{self.text[entity.start: entity.end]}] is a misaligned entity mention" + spans = [(ent.label, ent.start, ent.end) for ent in annotation.entities] + try: + self._span_utils.validate(spans) + except ValueError: + corrected_spans = self._span_utils.correct(spans) + self._span_utils.validate(corrected_spans) + for ent, span in zip(annotation.entities, corrected_spans): + ent.start, ent.end = span[1], span[2] def task(cls) -> TaskType: """The record task type""" return TaskType.token_classification + @property + def span_utils(self) -> SpanUtils: + """Utility class for span operations.""" + return self._span_utils + @property def predicted(self) -> Optional[PredictionStatus]: if self.annotation and self.prediction: @@ -250,29 +174,6 @@ def scores(self) -> List[float]: def all_text(self) -> str: return self.text - def predicted_iob_tags(self) -> Optional[List[str]]: - if self.prediction is None: - return None - return self.spans2iob(self.prediction.entities) - - def annotated_iob_tags(self) -> Optional[List[str]]: - if self.annotation is None: - return None - return self.spans2iob(self.annotation.entities) - - def spans2iob(self, spans: List[EntitySpan]) -> Optional[List[str]]: - if spans is None: - return None - tags = ["O"] * len(self.tokens) - for entity in spans: - token_start = typing.cast(int, self.char_id2token_id(entity.start)) - token_end = typing.cast(int, self.char_id2token_id(entity.end - 1)) - tags[token_start] = f"B-{entity.label}" - for idx in range(token_start + 1, token_end + 1): - tags[idx] = f"I-{entity.label}" - - return tags - def predicted_mentions(self) -> List[Tuple[str, EntitySpan]]: return [ (mention, entity) diff --git a/src/rubrix/utils/__init__.py b/src/rubrix/utils/__init__.py new file mode 100644 index 0000000000..e355d29a87 --- /dev/null +++ b/src/rubrix/utils/__init__.py @@ -0,0 +1,16 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .span_utils import SpanUtils +from .utils import LazyRubrixModule, limit_value_length, setup_loop_in_thread diff --git a/src/rubrix/utils/span_utils.py b/src/rubrix/utils/span_utils.py new file mode 100644 index 0000000000..201039cb4d --- /dev/null +++ b/src/rubrix/utils/span_utils.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Optional, Tuple + + +class SpanUtils: + """Holds utility methods to work with a tokenized text and entity spans. + + Spans must be tuples containing the label (str), start char idx (int), and end char idx (int). + + Args: + text: The text the spans refer to. + tokens: The tokens of the text. + """ + + def __init__(self, text: str, tokens: List[str]): + self._text, self._tokens = text, tokens + + self._token_to_char_idx: Dict[int, Tuple[int, int]] = {} + self._start_to_token_idx: Dict[int, int] = {} + self._end_to_token_idx: Dict[int, int] = {} + self._char_to_token_idx: Dict[int, int] = {} + + end_idx = 0 + for idx, token in enumerate(tokens): + start_idx = text.find(token, end_idx) + if start_idx == -1: + raise ValueError(f"Token '{token}' not found in text: {text}") + end_idx = start_idx + len(token) + + self._token_to_char_idx[idx] = (start_idx, end_idx) + self._start_to_token_idx[start_idx] = idx + self._end_to_token_idx[end_idx] = idx + for i in range(start_idx, end_idx): + self._char_to_token_idx[i] = idx + + # convention: skip first white space after a token + try: + if text[end_idx] == " ": + end_idx += 1 + # reached end of text + except IndexError: + pass + + @property + def text(self) -> str: + """The text the spans refer to.""" + return self._text + + @property + def tokens(self) -> List[str]: + """The tokens of the text.""" + return self._tokens + + @property + def token_to_char_idx(self) -> Dict[int, Tuple[int, int]]: + """The token index to start/end char index mapping.""" + return self._token_to_char_idx + + @property + def char_to_token_idx(self) -> Dict[int, int]: + """The char index to token index mapping.""" + return self._char_to_token_idx + + def validate(self, spans: List[Tuple[str, int, int]]): + """Validates the alignment of span boundaries and tokens. + + Args: + spans: A list of spans. + + Raises: + ValueError: If a span is invalid, or if a span is not aligned with the tokens. + """ + not_valid_spans_errors, misaligned_spans_errors = [], [] + + for span in spans: + char_start, char_end = span[1], span[2] + if char_end - char_start < 1: + not_valid_spans_errors.append(span) + elif None in ( + self._start_to_token_idx.get(char_start), + self._end_to_token_idx.get(char_end), + ): + message = f"- [{self.text[char_start:char_end]}] defined in " + if char_start - 5 > 0: + message += "..." + message += self.text[max(char_start - 5, 0) : char_end + 5] + if char_end + 5 < len(self.text): + message += "..." 
+ + misaligned_spans_errors.append(message) + + if not_valid_spans_errors or misaligned_spans_errors: + message = "" + if not_valid_spans_errors: + message += ( + f"Following entity spans are not valid: {not_valid_spans_errors}\n" + ) + + if misaligned_spans_errors: + spans = "\n".join(misaligned_spans_errors) + message += f"Following entity spans are not aligned with provided tokenization\n" + message += f"Spans:\n{spans}\n" + message += f"Tokens:\n{self.tokens}" + + raise ValueError(message) + + def correct(self, spans: List[Tuple[str, int, int]]) -> List[Tuple[str, int, int]]: + """Correct span boundaries for leading/trailing white spaces, new lines and tabs. + + Args: + spans: Spans to be corrected. + + Returns: + The corrected spans. + """ + corrected_spans = [] + for span in spans: + start, end = span[1], span[2] + + if start < 0: + start = 0 + if end > len(self.text): + end = len(self.text) + + while start <= len(self.text) and not self.text[start].strip(): + start += 1 + while not self.text[end - 1].strip(): + end -= 1 + + corrected_spans.append((span[0], start, end)) + + return corrected_spans + + def to_tags(self, spans: List[Tuple[str, int, int]]) -> List[str]: + """Convert spans to IOB tags. + + Args: + spans: Spans to transform into IOB tags. + + Returns: + The IOB tags. + + Raises: + ValueError: If spans overlap, the IOB format does not support overlapping spans. + """ + # check for overlapping spans + sorted_spans = sorted(spans, key=lambda x: x[1]) + for i in range(1, len(spans)): + if sorted_spans[i - 1][2] > sorted_spans[i][1]: + raise ValueError("IOB tags cannot handle overlapping spans!") + + tags = ["O"] * len(self.tokens) + for span in spans: + start_token_idx = self._start_to_token_idx[span[1]] + end_token_idx = self._end_to_token_idx[span[2]] + + tags[start_token_idx] = f"B-{span[0]}" + for token_idx in range(start_token_idx + 1, end_token_idx + 1): + tags[token_idx] = f"I-{span[0]}" + + return tags + + def from_tags(self, tags: List[str]) -> List[Tuple[str, int, int]]: + """Convert IOB or BILOU tags to spans. + + Overlapping spans are NOT supported! + + Args: + tags: The IOB or BILOU tags. + + Returns: + A list of spans. + + Raises: + ValueError: If the list of tags has not the same length as the list of tokens, + or tags are not in the IOB or BILOU format. + """ + + def get_prefix_and_entity(tag_str: str) -> Tuple[str, Optional[str]]: + if tag_str == "O": + return tag_str, None + splits = tag_str.split("-") + return splits[0], "-".join(splits[1:]) + + if len(tags) != len(self.tokens): + raise ValueError( + "The list of tags must have the same length as the list of tokens!" 
+ ) + + spans, start_idx = [], None + for idx, tag in enumerate(tags): + prefix, entity = get_prefix_and_entity(tag) + + if prefix == "O": + continue + + if prefix == "U": + start_idx, end_idx = self._token_to_char_idx[idx] + spans.append((entity, start_idx, end_idx)) + start_idx = None + continue + + if prefix == "L": + # If no start prefix, we just assume "L" == "U": + if start_idx is None: + start_idx, end_idx = self._token_to_char_idx[idx] + else: + _, end_idx = self._token_to_char_idx[idx] + spans.append((entity, start_idx, end_idx)) + start_idx = None + continue + + if prefix == "B": + start_idx, end_idx = self._token_to_char_idx[idx] + elif prefix == "I": + # If "B" is missing, we just assume "I" starts the span + if start_idx is None: + start_idx = self._token_to_char_idx[idx][0] + end_idx = self._token_to_char_idx[idx][1] + else: + raise ValueError("Tags are not in the IOB or BILOU format!") + + try: + next_tag = tags[idx + 1] + # Reached last tag, add span + except IndexError: + spans.append((entity, start_idx, end_idx)) + break + + next_prefix, next_entity = get_prefix_and_entity(next_tag) + # span continues + if next_prefix in ["I", "L"] and next_entity == entity: + continue + # span ends + spans.append((entity, start_idx, end_idx)) + start_idx = None + + return spans diff --git a/src/rubrix/utils.py b/src/rubrix/utils/utils.py similarity index 99% rename from src/rubrix/utils.py rename to src/rubrix/utils/utils.py index bf8e47c3a2..cd46115ad3 100644 --- a/src/rubrix/utils.py +++ b/src/rubrix/utils/utils.py @@ -22,7 +22,7 @@ from typing import Any, Optional, Tuple -class _LazyRubrixModule(ModuleType): +class LazyRubrixModule(ModuleType): """Module class that surfaces all objects but only performs associated imports when the objects are requested. Shamelessly copied and adapted from the Hugging Face transformers implementation. 
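The new `SpanUtils` class introduced above centralizes the character/token index mapping, span validation and correction, and the span-to-IOB-tag conversion that the client `TokenClassificationRecord` and the server-side record previously implemented separately. A minimal usage sketch of the API added in `src/rubrix/utils/span_utils.py` (not part of the diff; the example text, tokens and label are made up for illustration):

```python
from rubrix.utils import SpanUtils

text = "Rubrix is a tool."
tokens = ["Rubrix", "is", "a", "tool", "."]
span_utils = SpanUtils(text, tokens)

# The char/token index maps are built eagerly in __init__.
assert span_utils.token_to_char_idx[3] == (12, 16)  # "tool" (end index is exclusive)
assert span_utils.char_to_token_idx[7] == 1         # char 7 ("i") belongs to token "is"

# A span with a leading white space is misaligned with the tokenization ...
spans = [("TOOL", 11, 16)]  # " tool"
try:
    span_utils.validate(spans)
except ValueError:
    # ... but correct() trims the boundaries so that validation passes.
    spans = span_utils.correct(spans)
    span_utils.validate(spans)  # spans is now [("TOOL", 12, 16)]

# Round trip between entity spans and IOB tags.
tags = span_utils.to_tags(spans)                      # ["O", "O", "O", "B-TOOL", "O"]
assert span_utils.from_tags(tags) == [("TOOL", 12, 16)]
```

Note that `token_to_char_idx` stores exclusive end indices (`start + len(token)`), which is why the token-level metrics above subtract one from `char_end`: they keep the previous, inclusive Elasticsearch field definition until the TODO about aligning it with the entity-level convention is resolved.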
diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 86d48bd542..3397e8c6e9 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -533,28 +533,34 @@ def test_load_as_pandas(mocked_client): assert [record.id for record in records] == [0, 1, 2, 3] -def test_token_classification_spans(mocked_client): - dataset = "test_token_classification_with_consecutive_spans" +@pytest.mark.parametrize( + "span,valid", + [ + ((1, 2), False), + ((0, 4), True), + ((0, 5), True), # automatic correction + ], +) +def test_token_classification_spans(span, valid): texto = "Esto es una prueba" - item = api.TokenClassificationRecord( - text=texto, - tokens=texto.split(), - prediction=[("test", 1, 2)], # Inicio y fin son consecutivos - prediction_agent="test", - ) - with pytest.raises( - Exception, match=r"Defined offset \[s\] is a misaligned entity mention" - ): - api.log(item, name=dataset) - - item.prediction = [("test", 0, 6)] - with pytest.raises( - Exception, match=r"Defined offset \[Esto e\] is a misaligned entity mention" - ): - api.log(item, name=dataset) - - item.prediction = [("test", 0, 4)] - api.log(item, name=dataset) + if valid: + rb.TokenClassificationRecord( + text=texto, + tokens=texto.split(), + prediction=[("test", *span)], + ) + else: + with pytest.raises( + ValueError, + match="Following entity spans are not aligned with provided tokenization\n" + r"Spans:\n- \[s\] defined in Esto es...\n" + r"Tokens:\n\['Esto', 'es', 'una', 'prueba'\]", + ): + rb.TokenClassificationRecord( + text=texto, + tokens=texto.split(), + prediction=[("test", *span)], + ) def test_load_text2text(mocked_client): diff --git a/tests/client/test_asgi.py b/tests/client/test_asgi.py index 6b157b0bdc..51fdae1e80 100644 --- a/tests/client/test_asgi.py +++ b/tests/client/test_asgi.py @@ -100,8 +100,8 @@ def mock_predict(request): return JSONResponse( content=[ [ - {"label": "fawn", "start": 1, "end": 10}, - {"label": "fobis", "start": 12, "end": 14}, + {"label": "fawn", "start": 0, "end": 3}, + {"label": "fobis", "start": 4, "end": 8}, ], [], ] diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py index 55499c7e6b..d2f6940a7f 100644 --- a/tests/functional_tests/test_log_for_token_classification.py +++ b/tests/functional_tests/test_log_for_token_classification.py @@ -12,7 +12,9 @@ def test_log_with_empty_text(mocked_client): text = " " rubrix.delete(dataset) - with pytest.raises(Exception, match="No text or empty text provided"): + with pytest.raises( + Exception, match="The provided `text` contains only whitespaces." + ): rubrix.log( TokenClassificationRecord(id=0, text=text, tokens=["a", "b", "c"]), name=dataset, diff --git a/tests/test_init.py b/tests/test_init.py index 3113533e7d..1522b7b93f 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -17,11 +17,11 @@ import sys from rubrix.logging import LoguruLoggerHandler -from rubrix.utils import _LazyRubrixModule +from rubrix.utils import LazyRubrixModule def test_lazy_module(): - assert isinstance(sys.modules["rubrix"], _LazyRubrixModule) + assert isinstance(sys.modules["rubrix"], LazyRubrixModule) def test_configure_logging_call(): diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000000..168dce9eb1 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/utils/test_span_utils.py b/tests/utils/test_span_utils.py new file mode 100644 index 0000000000..e065b65ca1 --- /dev/null +++ b/tests/utils/test_span_utils.py @@ -0,0 +1,157 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from rubrix.utils.span_utils import SpanUtils + + +def test_init(): + text = "test this." + tokens = ["test", "this", "."] + + span_utils = SpanUtils(text, tokens) + + assert span_utils.text is text + assert span_utils.tokens is tokens + + assert span_utils.token_to_char_idx == {0: (0, 4), 1: (5, 9), 2: (9, 10)} + assert span_utils.char_to_token_idx == { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 5: 1, + 6: 1, + 7: 1, + 8: 1, + 9: 2, + } + + assert span_utils._start_to_token_idx == {0: 0, 5: 1, 9: 2} + assert span_utils._end_to_token_idx == {4: 0, 9: 1, 10: 2} + + +def test_init_value_error(): + with pytest.raises( + ValueError, match="Token 'ValueError' not found in text: test error" + ): + SpanUtils(text="test error", tokens=["test", "ValueError"]) + + +def test_validate(): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + assert span_utils.validate([("mock", 5, 10)]) is None + + +def test_validate_not_valid_spans(): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + with pytest.raises( + ValueError, match="Following entity spans are not valid: \[\('mock', 2, 1\)\]\n" + ): + span_utils.validate([("mock", 2, 1)]) + + +def test_validate_misaligned_spans(): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + with pytest.raises( + ValueError, + match="Following entity spans are not aligned with provided tokenization\n" + r"Spans:\n- \[test \] defined in test this.\n" + r"Tokens:\n\['test', 'this', '.'\]", + ): + span_utils.validate([("mock", 0, 5)]) + + +def test_validate_not_valid_and_misaligned_spans(): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + with pytest.raises( + ValueError, + match=r"Following entity spans are not valid: \[\('mock', 2, 1\)\]\n" + "Following entity spans are not aligned with provided tokenization\n" + r"Spans:\n- \[test \] defined in test this.\n" + r"Tokens:\n\['test', 'this', '.'\]", + ): + span_utils.validate([("mock", 2, 1), ("mock", 0, 5)]) + + +@pytest.mark.parametrize( + "spans, expected", + [ + ([("mock", -1, 4), ("mock", 20, 22)], [("mock", 0, 4), ("mock", 20, 21)]), + ([("mock", 0, 5), ("mock", 4, 9)], [("mock", 0, 4), ("mock", 5, 9)]), + 
([("mock", 10, 15), ("mock", 11, 16)], [("mock", 11, 15), ("mock", 11, 15)]), + ], +) +def test_correct(spans, expected): + text = "test this \nnext\ttext." + tokens = ["test", "this", "\n", "next", "\t", "text", "."] + span_utils = SpanUtils(text, tokens) + + assert span_utils.correct(spans) == expected + + +def test_to_tags_overlapping_spans(): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + with pytest.raises(ValueError, match="IOB tags cannot handle overlapping spans!"): + span_utils.to_tags([("mock", 0, 4), ("mock", 0, 9)]) + + +def test_to_tags(): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + tags = span_utils.to_tags([("mock", 0, 9)]) + assert tags == ["B-mock", "I-mock", "O"] + + +def test_from_tags_wrong_length(): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + with pytest.raises( + ValueError, + match="The list of tags must have the same length as the list of tokens!", + ): + span_utils.from_tags(["mock", "mock"]) + + +def test_from_tags_not_valid_format(): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + with pytest.raises(ValueError, match="Tags are not in the IOB or BILOU format!"): + span_utils.from_tags(["mock", "mock", "mock"]) + + +@pytest.mark.parametrize( + "tags,expected", + [ + (["B-mock", "O", "O"], [("mock", 0, 4)]), + (["I-mock", "O", "O"], [("mock", 0, 4)]), + (["U-mock", "O", "O"], [("mock", 0, 4)]), + (["L-mock", "O", "O"], [("mock", 0, 4)]), + (["B-mock", "I-mock", "O"], [("mock", 0, 9)]), + (["I-mock", "I-mock", "O"], [("mock", 0, 9)]), + (["B-mock", "L-mock", "O"], [("mock", 0, 9)]), + (["I-mock", "L-mock", "O"], [("mock", 0, 9)]), + (["B-mock", "I-mock", "I-mock"], [("mock", 0, 10)]), + (["I-mock", "I-mock", "I-mock"], [("mock", 0, 10)]), + (["B-mock", "I-mock", "L-mock"], [("mock", 0, 10)]), + (["B-mock", "L-mock", "L-mock"], [("mock", 0, 9), ("mock", 9, 10)]), + (["U-mock", "U-mock", "O"], [("mock", 0, 4), ("mock", 5, 9)]), + (["U-mock", "I-mock", "O"], [("mock", 0, 4), ("mock", 5, 9)]), + (["B-mock", "B-mock", "O"], [("mock", 0, 4), ("mock", 5, 9)]), + (["U-mock", "B-mock", "O"], [("mock", 0, 4), ("mock", 5, 9)]), + (["I-mock", "B-mock", "O"], [("mock", 0, 4), ("mock", 5, 9)]), + (["L-mock", "B-mock", "O"], [("mock", 0, 4), ("mock", 5, 9)]), + ], +) +def test_from_tags(tags, expected): + span_utils = SpanUtils("test this.", ["test", "this", "."]) + assert span_utils.from_tags(tags) == expected diff --git a/tests/test_utils.py b/tests/utils/test_utils.py similarity index 94% rename from tests/test_utils.py rename to tests/utils/test_utils.py index 2410797e44..65fcce34ae 100644 --- a/tests/test_utils.py +++ b/tests/utils/test_utils.py @@ -14,7 +14,7 @@ # limitations under the License. import pytest -from rubrix.utils import _LazyRubrixModule +from rubrix.utils import LazyRubrixModule def test_lazy_rubrix_module(monkeypatch): @@ -23,7 +23,7 @@ def mock_import_module(name, package): monkeypatch.setattr("importlib.import_module", mock_import_module) - lazy_module = _LazyRubrixModule( + lazy_module = LazyRubrixModule( name="rb_mock", module_file="rb_mock_file", import_structure={"mock_module": ["title"]}, @@ -57,7 +57,7 @@ def mock_import_module(*args, **kwargs): monkeypatch.setattr("importlib.import_module", mock_import_module) - lazy_module = _LazyRubrixModule( + lazy_module = LazyRubrixModule( name="rb_mock", module_file=__file__, import_structure={"mock_module": ["title"]},