diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5cf89e619c..ee7b8c47b2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,8 @@ jobs: pytest --cov=rubrix --cov-report=xml pip install "spacy<3.0" && python -m spacy download en_core_web_sm pytest tests/monitoring/test_spacy_monitoring.py + pip install "cleanlab<2.0" + pytest tests/labeling/text_classification/test_label_errors.py - name: Upload Coverage to Codecov 📦 uses: codecov/codecov-action@v1 diff --git a/environment_dev.yml b/environment_dev.yml index d70ed0e6eb..594d020702 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -30,7 +30,7 @@ dependencies: # code formatting - pre-commit==2.15.0 # extra test dependencies - - cleanlab<2.0.0 + - cleanlab - datasets>1.17.0 - huggingface_hub != 0.5.0 # some backward comp. problems introduced in 0.5.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0.tar.gz diff --git a/src/rubrix/labeling/text_classification/label_errors.py b/src/rubrix/labeling/text_classification/label_errors.py index 0877df1c5b..fe2d3c19fc 100644 --- a/src/rubrix/labeling/text_classification/label_errors.py +++ b/src/rubrix/labeling/text_classification/label_errors.py @@ -17,6 +17,7 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np +from pkg_resources import parse_version from rubrix.client.datasets import DatasetForTextClassification from rubrix.client.models import TextClassificationRecord @@ -59,7 +60,8 @@ def find_label_errors( metadata_key: The key added to the record's metadata that holds the order, if ``sort_by`` is not "none". n_jobs : Number of processing threads used by multiprocessing. If None, uses the number of threads on your CPU. Defaults to 1, which removes parallel processing. - **kwargs: Passed on to `cleanlab.pruning.get_noise_indices` + **kwargs: Passed on to `cleanlab.pruning.get_noise_indices` (cleanlab < 2.0) + or `cleanlab.filter.find_label_issues` (cleanlab >= 2.0) Returns: A list of records containing potential annotation/label errors @@ -82,7 +84,10 @@ def find_label_errors( "You can install 'cleanlab' with the command: `pip install cleanlab`" ) else: - from cleanlab.pruning import get_noise_indices + if parse_version(cleanlab.__version__) < parse_version("2.0"): + from cleanlab.pruning import get_noise_indices as find_label_issues + else: + from cleanlab.filter import find_label_issues if isinstance(sort_by, str): sort_by = SortBy(sort_by) @@ -95,12 +100,12 @@ def find_label_errors( ) # check and update kwargs for get_noise_indices - _check_and_update_kwargs(records[0], sort_by, kwargs) + _check_and_update_kwargs(cleanlab.__version__, records[0], sort_by, kwargs) # construct "noisy" label vector and probability matrix of the predictions s, psx = _construct_s_and_psx(records) - indices = get_noise_indices(s, psx, n_jobs=n_jobs, **kwargs) + indices = find_label_issues(s, psx, n_jobs=n_jobs, **kwargs) records_with_label_errors = np.array(records)[indices].tolist() @@ -113,11 +118,12 @@ def find_label_errors( def _check_and_update_kwargs( - record: TextClassificationRecord, sort_by: SortBy, kwargs: Dict + version: str, record: TextClassificationRecord, sort_by: SortBy, kwargs: Dict ): """Helper function to check and update the kwargs passed on to cleanlab's `get_noise_indices`. Args: + version: version of cleanlab record: One of the records passed in the `find_label_error` function. sort_by: The sorting policy. kwargs: The passed on kwargs. @@ -125,16 +131,6 @@ def _check_and_update_kwargs( Raises: ValueError: If not supported kwargs ('sorted_index_method') are passed on. """ - if "sorted_index_method" in kwargs: - raise ValueError( - "The 'sorted_index_method' kwarg is not supported, please use 'sort_by' instead." - ) - kwargs["sorted_index_method"] = "normalized_margin" - if sort_by is SortBy.PREDICTION: - kwargs["sorted_index_method"] = "prob_given_label" - elif sort_by is SortBy.NONE: - kwargs["sorted_index_method"] = None - if "multi_label" in kwargs: _LOGGER.warning( "You provided the kwarg 'multi_label', but it is determined automatically. " @@ -142,6 +138,34 @@ def _check_and_update_kwargs( ) kwargs["multi_label"] = record.multi_label + if parse_version(version) < parse_version("2.0"): + if "sorted_index_method" in kwargs: + raise ValueError( + "The 'sorted_index_method' kwarg is not supported, please use 'sort_by' instead." + ) + kwargs["sorted_index_method"] = "normalized_margin" + if sort_by is SortBy.PREDICTION: + kwargs["sorted_index_method"] = "prob_given_label" + elif sort_by is SortBy.NONE: + kwargs["sorted_index_method"] = None + else: + if "return_indices_ranked_by" in kwargs: + raise ValueError( + "The 'return_indices_ranked_by' kwarg is not supported, please use 'sort_by' instead." + ) + kwargs["return_indices_ranked_by"] = "normalized_margin" + if sort_by is SortBy.PREDICTION: + kwargs["return_indices_ranked_by"] = "self_confidence" + elif sort_by is SortBy.NONE: + kwargs["return_indices_ranked_by"] = None + # TODO: Remove this once https://github.com/cleanlab/cleanlab/issues/243 is solved + elif kwargs["multi_label"]: + _LOGGER.warning( + "With cleanlab v2 and multi-label records there is an issue sorting records by 'likelihood' " + "(https://github.com/cleanlab/cleanlab/issues/243). We will sort by 'prediction' instead." + ) + kwargs["return_indices_ranked_by"] = "self_confidence" + def _construct_s_and_psx( records: List[TextClassificationRecord], diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py index 78b4479330..2900c1398b 100644 --- a/tests/labeling/text_classification/test_label_errors.py +++ b/tests/labeling/text_classification/test_label_errors.py @@ -14,7 +14,9 @@ # limitations under the License. import sys +import cleanlab import pytest +from pkg_resources import parse_version import rubrix as rb from rubrix.labeling.text_classification import find_label_errors @@ -77,7 +79,10 @@ def test_no_records(): def test_multi_label_warning(caplog): record = rb.TextClassificationRecord( - text="test", prediction=[("mock", 0.0)], annotation="mock" + text="test", + prediction=[("mock", 0.0), ("mock2", 0.0)], + annotation=["mock", "mock2"], + multi_label=True, ) find_label_errors([record], multi_label="True") assert ( @@ -89,20 +94,32 @@ def test_multi_label_warning(caplog): @pytest.mark.parametrize( "sort_by,expected", [ - ("likelihood", "normalized_margin"), - ("prediction", "prob_given_label"), - ("none", None), + ("likelihood", ("normalized_margin", "normalized_margin")), + ("prediction", ("prob_given_label", "self_confidence")), + ("none", (None, None)), ], ) def test_sort_by(monkeypatch, sort_by, expected): - def mock_get_noise_indices(*args, **kwargs): - assert kwargs["sorted_index_method"] == expected - return [] + if parse_version(cleanlab.__version__) < parse_version("2.0"): - monkeypatch.setattr( - "cleanlab.pruning.get_noise_indices", - mock_get_noise_indices, - ) + def mock_get_noise_indices(*args, **kwargs): + assert kwargs["sorted_index_method"] == expected[0] + return [] + + monkeypatch.setattr( + "cleanlab.pruning.get_noise_indices", + mock_get_noise_indices, + ) + else: + + def mock_find_label_issues(*args, **kwargs): + assert kwargs["return_indices_ranked_by"] == expected[1] + return [] + + monkeypatch.setattr( + "cleanlab.filter.find_label_issues", + mock_find_label_issues, + ) record = rb.TextClassificationRecord( inputs="mock", prediction=[("mock", 0.1)], annotation="mock" @@ -113,25 +130,50 @@ def mock_get_noise_indices(*args, **kwargs): def test_kwargs(monkeypatch, records): is_multi_label = records[0].multi_label - def mock_get_noise_indices(s, psx, n_jobs, **kwargs): - assert kwargs == { - "multi_label": is_multi_label, - "sorted_index_method": "normalized_margin", - "mock": "mock", - } - return [] - - monkeypatch.setattr( - "cleanlab.pruning.get_noise_indices", - mock_get_noise_indices, - ) + if parse_version(cleanlab.__version__) < parse_version("2.0"): - with pytest.raises( - ValueError, match="'sorted_index_method' kwarg is not supported" - ): - find_label_errors(records=records, sorted_index_method="mock") + def mock_get_noise_indices(s, psx, n_jobs, **kwargs): + assert kwargs == { + "mock": "mock", + "multi_label": is_multi_label, + "sorted_index_method": "normalized_margin", + } + return [] + + monkeypatch.setattr( + "cleanlab.pruning.get_noise_indices", + mock_get_noise_indices, + ) + + with pytest.raises( + ValueError, match="'sorted_index_method' kwarg is not supported" + ): + find_label_errors(records=records, sorted_index_method="mock") + + find_label_errors(records=records, mock="mock") + else: + + def mock_find_label_issues(s, psx, n_jobs, **kwargs): + assert kwargs == { + "mock": "mock", + "multi_label": is_multi_label, + "return_indices_ranked_by": "normalized_margin" + if not is_multi_label + else "self_confidence", + } + return [] + + monkeypatch.setattr( + "cleanlab.filter.find_label_issues", + mock_find_label_issues, + ) + + with pytest.raises( + ValueError, match="'return_indices_ranked_by' kwarg is not supported" + ): + find_label_errors(records=records, return_indices_ranked_by="mock") - find_label_errors(records=records, mock="mock") + find_label_errors(records=records, mock="mock") def test_construct_s_and_psx(records):