feat(#1225): create iob tags from record spans #1226

Merged
8 commits merged on Mar 7, 2022
2 changes: 1 addition & 1 deletion src/rubrix/client/datasets.py
@@ -597,7 +597,7 @@ def entities_to_tuple(entities):
return cls(records)

@classmethod
def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTextClassification":
def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTokenClassification":
return cls(
[TokenClassificationRecord(**row) for row in dataframe.to_dict("records")]
)
124 changes: 121 additions & 3 deletions src/rubrix/client/models.py
@@ -19,10 +19,12 @@

import datetime
import warnings
from collections import defaultdict
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
from pydantic import BaseModel, Field, root_validator, validator
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.commons.helpers import limit_value_length
@@ -228,8 +230,8 @@ class TokenClassificationRecord(_Validators):
... )
"""

text: str
tokens: List[str]
text: str = Field(min_length=1)
tokens: Union[List[str], Tuple[str, ...]]

prediction: Optional[
List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]]
@@ -246,6 +248,23 @@
metrics: Optional[Dict[str, Any]] = None
search_keywords: Optional[List[str]] = None

__chars2tokens__: Dict[int, int] = PrivateAttr(default=None)
__tokens2chars__: Dict[int, Tuple[int, int]] = PrivateAttr(default=None)

def __setattr__(self, name: str, value: Any):
"""Make text and tokens immutable"""
if name in ["text", "tokens"]:
raise AttributeError(f"You cannot assign a new value to `{name}`")
super().__setattr__(name, value)

@validator("tokens", pre=True)
def _normalize_tokens(cls, value):
if isinstance(value, list):
value = tuple(value)

assert len(value) > 0, "At least one token should be provided"
return value

@validator("prediction")
def add_default_score(
cls,
@@ -261,6 +280,105 @@ def add_default_score(
for pred in prediction
]

@staticmethod
def __build_indices_map__(
text: str, tokens: Tuple[str, ...]
) -> Tuple[Dict[int, int], Dict[int, Tuple[int, int]]]:
"""
Build the indices mapping between text characters and tokens where belongs to,
and vice versa.

chars2tokens index contains is the token idx where i char is contained (if any).

Out-of-token characters won't be included in this map,
so access should be using ``chars2tokens_map.get(i)``
instead of ``chars2tokens_map[i]``.

"""

def chars2tokens_index(text_, tokens_):
chars_map = {}
current_token = 0
current_token_char_start = 0
for idx, char in enumerate(text_):
relative_idx = idx - current_token_char_start
if (
relative_idx < len(tokens_[current_token])
and char == tokens_[current_token][relative_idx]
):
chars_map[idx] = current_token
elif (
current_token + 1 < len(tokens_)
and relative_idx >= len(tokens_[current_token])
and char == tokens_[current_token + 1][0]
):
current_token += 1
current_token_char_start += relative_idx
chars_map[idx] = current_token

return chars_map

def tokens2chars_index(
chars2tokens: Dict[int, int]
) -> Dict[int, Tuple[int, int]]:
tokens2chars_map = defaultdict(list)
for c, t in chars2tokens.items():
tokens2chars_map[t].append(c)

return {
token_idx: (min(chars), max(chars))
for token_idx, chars in tokens2chars_map.items()
}

chars2tokens_idx = chars2tokens_index(text_=text, tokens_=tokens)
return chars2tokens_idx, tokens2chars_index(chars2tokens_idx)

def char_id2token_id(self, char_idx: int) -> Optional[int]:
"""
Given a character index, returns the index of the token it belongs to,
or ``None`` if the character is not contained in any token.
"""

if self.__chars2tokens__ is None:
self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__(
self.text, tuple(self.tokens)
)
return self.__chars2tokens__.get(char_idx)

def token_span(self, token_idx: int) -> Tuple[int, int]:
"""
Given a token index, returns the start and end character positions of that token.
Raises an ``IndexError`` if the token index is out of bounds.
"""
if self.__tokens2chars__ is None:
self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__(
self.text, tuple(self.tokens)
)
if token_idx not in self.__tokens2chars__:
raise IndexError(f"Token id {token_idx} out of bounds")
return self.__tokens2chars__[token_idx]

def spans2iob(
self, spans: Optional[List[Tuple[str, int, int]]] = None
) -> Optional[List[str]]:
"""Build the iob tags sequence for a list of spans annoations"""

if spans is None:
return None

tags = ["O"] * len(self.tokens)
for label, start, end in spans:
token_start = self.char_id2token_id(start)
token_end = self.char_id2token_id(end - 1)
assert (
token_start is not None and token_end is not None
), "Provided spans are missaligned at token level"
tags[token_start] = f"B-{label}"
for idx in range(token_start + 1, token_end + 1):
tags[idx] = f"I-{label}"

return tags


class Text2TextRecord(_Validators):
"""Record for a text to text task
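
A minimal usage sketch of the helpers added above (``char_id2token_id``, ``token_span``, and ``spans2iob``). The record text, tokens, and span below are illustrative only and not taken from this PR:

from rubrix.client.models import TokenClassificationRecord

# Illustrative record; any text/tokens/spans aligned the same way behave identically.
record = TokenClassificationRecord(
    text="Rubrix is great",
    tokens=["Rubrix", "is", "great"],
    annotation=[("PRODUCT", 0, 6)],
)

# Character index -> token index (None for out-of-token characters such as spaces).
assert record.char_id2token_id(0) == 0

# Token index -> (start, end) character positions of that token.
assert record.token_span(2) == (10, 14)

# Span annotations -> IOB tag sequence, one tag per token.
assert record.spans2iob(record.annotation) == ["B-PRODUCT", "O", "O"]
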
5 changes: 4 additions & 1 deletion tests/client/test_asgi.py
@@ -72,10 +72,13 @@ def __call__(self, records, name: str, **kwargs):
],
)

time.sleep(0.2)
assert mock_log.was_called
time.sleep(0.200)

mock_log.was_called = False
mock.get("/another/predict/route")

time.sleep(0.2)
assert not mock_log.was_called


30 changes: 24 additions & 6 deletions tests/client/test_models.py
@@ -63,20 +63,38 @@ def test_text_classification_input_string():


@pytest.mark.parametrize(
("annotation", "status", "expected_status"),
("annotation", "status", "expected_status", "expected_iob"),
[
(None, None, "Default"),
([("test", 0, 5)], None, "Validated"),
(None, "Discarded", "Discarded"),
([("test", 0, 5)], "Discarded", "Discarded"),
(None, None, "Default", None),
([("test", 0, 4)], None, "Validated", ["B-test", "O"]),
(None, "Discarded", "Discarded", None),
([("test", 0, 9)], "Discarded", "Discarded", ["B-test", "I-test"]),
],
)
def test_token_classification_record(annotation, status, expected_status):
def test_token_classification_record(annotation, status, expected_status, expected_iob):
"""Just testing its dynamic defaults"""
record = TokenClassificationRecord(
text="test text", tokens=["test", "text"], annotation=annotation, status=status
)
assert record.status == expected_status
assert record.spans2iob(record.annotation) == expected_iob


def test_token_classification_with_mutation():
text_a = "The text"
text_b = "Another text sample here !!!"

record = TokenClassificationRecord(
text=text_a, tokens=text_a.split(" "), annotation=[]
)
assert record.spans2iob(record.annotation) == ["O"] * len(text_a.split(" "))

with pytest.raises(AttributeError, match="You cannot assign a new value to `text`"):
record.text = text_b
with pytest.raises(
AttributeError, match="You cannot assign a new value to `tokens`"
):
record.tokens = text_b.split(" ")


@pytest.mark.parametrize(
@@ -23,7 +23,7 @@ def test_log_with_empty_tokens_list(mocked_client):
rubrix.delete(dataset)
with pytest.raises(
Exception,
match="ensure this value has at least 1 items",
match="At least one token should be provided",
):
rubrix.log(
TokenClassificationRecord(id=0, text=text, tokens=[]),
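
For reference, a minimal sketch of the behaviour covered by the updated test above, assuming only this client change. With the new ``tokens`` validator, the error is raised when the record is constructed, before ``rubrix.log`` is ever reached; the id and text below are illustrative:

import pytest
from rubrix.client.models import TokenClassificationRecord

# Pydantic turns the failed assertion in the validator into a validation
# error at construction time, carrying the new message.
with pytest.raises(Exception, match="At least one token should be provided"):
    TokenClassificationRecord(id=0, text="The text", tokens=[])
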