From ecbdd78bd526b36b97a65e3fc04fb180e36a8ec3 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 7 Mar 2022 22:21:51 +0100 Subject: [PATCH] feat(#1225): create iob tags from record spans (#1226) * feat(#1225): create iob tags from record spans * test: add tests * refactor: dynamic tokens map with text/tokens mutability * chore: naming * feat: make text and tokens immutable * chore: adapt to immutable text and tokens * test: fix tests * test: fixing tests Co-authored-by: dcfidalgo (cherry picked from commit 07b895d501d34eb338bc8a3b33a3ac089606e989) --- src/rubrix/client/datasets.py | 2 +- src/rubrix/client/models.py | 2 ++ tests/client/test_asgi.py | 5 ++++- tests/client/test_models.py | 12 ++++++------ .../test_log_for_token_classification.py | 2 +- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/rubrix/client/datasets.py b/src/rubrix/client/datasets.py index 212da63cdc..4f3f768263 100644 --- a/src/rubrix/client/datasets.py +++ b/src/rubrix/client/datasets.py @@ -877,7 +877,7 @@ def parse_tags_from_example(example): return dataset.map(parse_tags_from_example) @classmethod - def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTextClassification": + def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTokenClassification": return cls( [TokenClassificationRecord(**row) for row in dataframe.to_dict("records")] ) diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py index 1544ec6db6..440358be9a 100644 --- a/src/rubrix/client/models.py +++ b/src/rubrix/client/models.py @@ -21,6 +21,7 @@ import logging import warnings from collections import defaultdict +from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd @@ -381,6 +382,7 @@ def chars2tokens_index(text_, tokens_): current_token += 1 current_token_char_start += relative_idx chars_map[idx] = current_token + return chars_map def tokens2chars_index( diff --git a/tests/client/test_asgi.py 
b/tests/client/test_asgi.py index e88ae90169..6b157b0bdc 100644 --- a/tests/client/test_asgi.py +++ b/tests/client/test_asgi.py @@ -72,10 +72,13 @@ def __call__(self, records, name: str, **kwargs): ], ) + time.sleep(0.2) assert mock_log.was_called - time.sleep(0.200) + mock_log.was_called = False mock.get("/another/predict/route") + + time.sleep(0.2) assert not mock_log.was_called diff --git a/tests/client/test_models.py b/tests/client/test_models.py index 9ba44a58b1..e637756421 100644 --- a/tests/client/test_models.py +++ b/tests/client/test_models.py @@ -63,15 +63,15 @@ def test_text_classification_input_string(): @pytest.mark.parametrize( - ("annotation", "status", "expected_status"), + ("annotation", "status", "expected_status", "expected_iob"), [ - (None, None, "Default"), - ([("test", 0, 5)], None, "Validated"), - (None, "Discarded", "Discarded"), - ([("test", 0, 5)], "Discarded", "Discarded"), + (None, None, "Default", None), + ([("test", 0, 4)], None, "Validated", ["B-test", "O"]), + (None, "Discarded", "Discarded", None), + ([("test", 0, 9)], "Discarded", "Discarded", ["B-test", "I-test"]), ], ) -def test_token_classification_record(annotation, status, expected_status): +def test_token_classification_record(annotation, status, expected_status, expected_iob): """Just testing its dynamic defaults""" record = TokenClassificationRecord( text="test text", tokens=["test", "text"], annotation=annotation, status=status diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py index add2a7e87f..ad50ff220d 100644 --- a/tests/functional_tests/test_log_for_token_classification.py +++ b/tests/functional_tests/test_log_for_token_classification.py @@ -23,7 +23,7 @@ def test_log_with_empty_tokens_list(mocked_client): rubrix.delete(dataset) with pytest.raises( Exception, - match="ensure this value has at least 1 items", + match="At least one token should be provided", ): rubrix.log( 
TokenClassificationRecord(id=0, text=text, tokens=[]),