Skip to content

Commit

Permalink
feat(token-class): adjust token spans spaces (#1599)
Browse files Browse the repository at this point in the history
* test: add tests

(cherry picked from commit b42bb6d)
  • Loading branch information
frascuchon committed Jul 8, 2022
1 parent 65747ab commit 0fb3576
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/rubrix/server/apis/v0/models/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,21 @@ def check_annotation(
annotation: Optional[TokenClassificationAnnotation],
):
"""Validates entities in terms of offset spans"""

def adjust_span_bounds(start, end):
    """Clamp a character span to the record text and trim surrounding whitespace.

    The text is read from ``self.text`` of the enclosing method's scope.

    Args:
        start: Proposed start offset of the span (may be negative).
        end: Proposed exclusive end offset (may exceed the text length).

    Returns:
        A ``(start, end)`` tuple. When the span contains no
        non-whitespace character the result satisfies ``start >= end``,
        so slicing the text with it yields an empty mention that the
        caller can reject.
    """
    text = self.text
    # Clamp using the function's own arguments rather than the loop
    # variable of the enclosing scope.
    start = max(start, 0)
    end = min(end, len(text))
    # Bounding both scans by the opposite edge keeps every index inside
    # the text: no IndexError at len(text), no wrap-around at index -1.
    while start < end and not text[start].strip():
        start += 1
    while end > start and not text[end - 1].strip():
        end -= 1
    return start, end

if annotation:
for entity in annotation.entities:
entity.start, entity.end = adjust_span_bounds(entity.start, entity.end)
mention = self.text[entity.start : entity.end]
assert len(mention) > 0, f"Empty offset defined for entity {entity}"

Expand Down
34 changes: 34 additions & 0 deletions tests/server/token_classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,39 @@ def test_annotated_without_entities():
assert record.predicted == PredictionStatus.KO


def test_adjust_spans():
    """Building a record should rewrite loosely bounded entity spans.

    Spans that start before the text, end past it, or are otherwise not
    cleanly bounded are expected to come back with adjusted offsets on
    both the prediction and the annotation.
    """
    text = "A text with some empty spaces that could bring not cleany annotated spans"

    # Input spans, deliberately out of bounds or loosely bounded.
    raw_prediction_spans = [
        EntitySpan(start=-3, end=2, label="DET"),
        EntitySpan(start=24, end=36, label="NAME"),
    ]
    raw_annotation_spans = [
        EntitySpan(start=48, end=61, label="VERB"),
        EntitySpan(start=68, end=100, label="DET"),
    ]

    prediction = TokenClassificationAnnotation(
        agent="pred.test",
        entities=raw_prediction_spans,
    )
    annotation = TokenClassificationAnnotation(
        agent="test",
        entities=raw_annotation_spans,
    )
    record = TokenClassificationRecord(
        text=text,
        tokens=text.split(),
        prediction=prediction,
        annotation=annotation,
    )

    # Offsets expected after the record has adjusted the spans.
    expected_prediction_spans = [
        EntitySpan(start=0, end=1, label="DET"),
        EntitySpan(start=28, end=34, label="NAME"),
    ]
    expected_annotation_spans = [
        EntitySpan(start=50, end=60, label="VERB"),
        EntitySpan(start=70, end=85, label="DET"),
    ]

    assert record.prediction.entities == expected_prediction_spans
    assert record.annotation.entities == expected_annotation_spans

def test_whitespace_in_tokens():
from spacy import load

Expand All @@ -261,3 +294,4 @@ def test_whitespace_in_tokens():
record = CreationTokenClassificationRecord.parse_obj(record)
assert record
assert record.tokens == ["every", "four", "(", "4", ")", " "]

0 comments on commit 0fb3576

Please sign in to comment.