Skip to content

Commit

Permalink
refactor(token-class): discarding first space after token
Browse files (browse the repository at this point in the history)
  • Loading branch information
frascuchon committed Jul 4, 2022
1 parent bd729ed commit e367fbd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
10 changes: 5 additions & 5 deletions src/rubrix/server/apis/v0/models/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,6 @@ def token_span(self, token_idx: int) -> Tuple[int, int]:
raise IndexError(f"Token id {token_idx} out of bounds")
return self.__tokens2chars__[token_idx]

@validator("tokens")
def remove_empty_tokens(cls, tokens: List[str]):
    """Return the provided tokens without the ones made only of spaces.

    A token is kept when at least one character other than the plain
    space character remains in it; empty strings are dropped as well.
    """
    # TODO: maybe launch a warning about changing provided tokens
    return [token for token in tokens if token.strip(" ")]

@validator("text")
def check_text_content(cls, text: str):
assert text and text.strip(), "No text or empty text provided"
Expand All @@ -160,10 +155,15 @@ def __build_indices_map__(
"""

def chars2tokens_index():
def is_space_after_token(char: str, idx: int, chars_map: dict) -> bool:
    """Tell whether ``char`` is a space that directly follows a mapped character.

    Args:
        char: The character found at position ``idx`` of the text.
        idx: Position of ``char`` inside the text.
        chars_map: Mapping keyed by the character positions registered so
            far (presumably position -> token index; confirm against the
            enclosing function).

    Returns:
        True when ``char`` is a plain space and the previous position
        (``idx - 1``) is a key of ``chars_map``; False otherwise.
    """
    # Only the plain space character qualifies: other whitespace (tabs,
    # newlines) is never reported as a separator by this helper.
    return char == " " and (idx - 1) in chars_map

chars_map = {}
current_token = 0
current_token_char_start = 0
for idx, char in enumerate(self.text):
if is_space_after_token(char, idx, chars_map):
continue
relative_idx = idx - current_token_char_start
if (
relative_idx < len(self.tokens[current_token])
Expand Down
16 changes: 9 additions & 7 deletions tests/server/token_classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,12 @@ def test_entities_with_spaces():


def test_model_dict():
text = "This is a great space"
tokens = ["This", "is", " ", "a", " ", "great", " ", "space"]
record = TokenClassificationRecord(
id="1",
text="This is a great space",
tokens=["This", "is", " ", "a", " ", "great", " ", "space"],
text=text,
tokens=tokens,
prediction=TokenClassificationAnnotation(
agent="test",
entities=[
Expand All @@ -100,10 +102,10 @@ def test_model_dict():
"agent": "test",
"entities": [{"end": 24, "label": "test", "score": 1.0, "start": 9}],
},
"raw_text": "This is a great space",
"raw_text": text,
"text": text,
"tokens": tokens,
"status": "Default",
"text": "This is a great space",
"tokens": ["This", "is", "a", "great", "space"],
}


Expand Down Expand Up @@ -244,7 +246,7 @@ def test_whitespace_in_tokens():
from spacy import load

nlp = load("en_core_web_sm")
text = "every four (4)"
text = "every four (4) "
doc = nlp(text)

record = {
Expand All @@ -258,4 +260,4 @@ def test_whitespace_in_tokens():

record = CreationTokenClassificationRecord.parse_obj(record)
assert record
print(record.dict())
assert record.tokens == ["every", "four", "(", "4", ")", " "]

0 comments on commit e367fbd

Please sign in to comment.