Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(#1428): support cleanlab v2 #1436

Merged
merged 6 commits on Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Expand Up @@ -69,6 +69,8 @@ jobs:
pytest --cov=rubrix --cov-report=xml
pip install "spacy<3.0" && python -m spacy download en_core_web_sm
pytest tests/monitoring/test_spacy_monitoring.py
pip install "cleanlab<2.0"
pytest tests/labeling/text_classification/test_label_errors.py

- name: Upload Coverage to Codecov 📦
uses: codecov/codecov-action@v1
Expand Down
2 changes: 1 addition & 1 deletion environment_dev.yml
Expand Up @@ -30,7 +30,7 @@ dependencies:
# code formatting
- pre-commit==2.15.0
# extra test dependencies
- cleanlab<2.0.0
- cleanlab
- datasets>1.17.0
- huggingface_hub != 0.5.0 # some backward comp. problems introduced in 0.5.0
- https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0.tar.gz
Expand Down
54 changes: 39 additions & 15 deletions src/rubrix/labeling/text_classification/label_errors.py
Expand Up @@ -17,6 +17,7 @@
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from pkg_resources import parse_version

from rubrix.client.datasets import DatasetForTextClassification
from rubrix.client.models import TextClassificationRecord
Expand Down Expand Up @@ -59,7 +60,8 @@ def find_label_errors(
metadata_key: The key added to the record's metadata that holds the order, if ``sort_by`` is not "none".
n_jobs : Number of processing threads used by multiprocessing. If None, uses the number of threads
on your CPU. Defaults to 1, which removes parallel processing.
**kwargs: Passed on to `cleanlab.pruning.get_noise_indices`
**kwargs: Passed on to `cleanlab.pruning.get_noise_indices` (cleanlab < 2.0)
or `cleanlab.filter.find_label_issues` (cleanlab >= 2.0)

Returns:
A list of records containing potential annotation/label errors
Expand All @@ -82,7 +84,10 @@ def find_label_errors(
"You can install 'cleanlab' with the command: `pip install cleanlab`"
)
else:
from cleanlab.pruning import get_noise_indices
if parse_version(cleanlab.__version__) < parse_version("2.0"):
from cleanlab.pruning import get_noise_indices as find_label_issues
else:
from cleanlab.filter import find_label_issues

if isinstance(sort_by, str):
sort_by = SortBy(sort_by)
Expand All @@ -95,12 +100,12 @@ def find_label_errors(
)

# check and update kwargs for get_noise_indices
_check_and_update_kwargs(records[0], sort_by, kwargs)
_check_and_update_kwargs(cleanlab.__version__, records[0], sort_by, kwargs)

# construct "noisy" label vector and probability matrix of the predictions
s, psx = _construct_s_and_psx(records)

indices = get_noise_indices(s, psx, n_jobs=n_jobs, **kwargs)
indices = find_label_issues(s, psx, n_jobs=n_jobs, **kwargs)

records_with_label_errors = np.array(records)[indices].tolist()

Expand All @@ -113,35 +118,54 @@ def find_label_errors(


def _check_and_update_kwargs(
    version: str, record: TextClassificationRecord, sort_by: SortBy, kwargs: Dict
):
    """Helper function to check and update the kwargs passed on to cleanlab.

    Depending on the installed cleanlab version, the kwargs are prepared either for
    ``cleanlab.pruning.get_noise_indices`` (cleanlab < 2.0) or for
    ``cleanlab.filter.find_label_issues`` (cleanlab >= 2.0).

    Args:
        version: Version string of the installed cleanlab package.
        record: One of the records passed in the `find_label_errors` function.
            Only its ``multi_label`` attribute is read here.
        sort_by: The sorting policy.
        kwargs: The passed on kwargs; modified in place.

    Raises:
        ValueError: If unsupported kwargs ('sorted_index_method' for cleanlab < 2.0,
            'return_indices_ranked_by' for cleanlab >= 2.0) are passed on.
    """
    # 'multi_label' is determined from the records themselves, never by the caller.
    if "multi_label" in kwargs:
        _LOGGER.warning(
            "You provided the kwarg 'multi_label', but it is determined automatically. "
            f"We will set it to '{record.multi_label}'."
        )
    kwargs["multi_label"] = record.multi_label

    if parse_version(version) < parse_version("2.0"):
        # cleanlab v1 API: sorting is controlled via 'sorted_index_method'.
        if "sorted_index_method" in kwargs:
            raise ValueError(
                "The 'sorted_index_method' kwarg is not supported, please use 'sort_by' instead."
            )
        kwargs["sorted_index_method"] = "normalized_margin"
        if sort_by is SortBy.PREDICTION:
            kwargs["sorted_index_method"] = "prob_given_label"
        elif sort_by is SortBy.NONE:
            kwargs["sorted_index_method"] = None
    else:
        # cleanlab v2 API: sorting is controlled via 'return_indices_ranked_by'.
        if "return_indices_ranked_by" in kwargs:
            raise ValueError(
                "The 'return_indices_ranked_by' kwarg is not supported, please use 'sort_by' instead."
            )
        kwargs["return_indices_ranked_by"] = "normalized_margin"
        if sort_by is SortBy.PREDICTION:
            kwargs["return_indices_ranked_by"] = "self_confidence"
        elif sort_by is SortBy.NONE:
            kwargs["return_indices_ranked_by"] = None
        # TODO: Remove this once https://github.com/cleanlab/cleanlab/issues/243 is solved
        elif kwargs["multi_label"]:
            _LOGGER.warning(
                "With cleanlab v2 and multi-label records there is an issue sorting records by 'likelihood' "
                "(https://github.com/cleanlab/cleanlab/issues/243). We will sort by 'prediction' instead."
            )
            kwargs["return_indices_ranked_by"] = "self_confidence"


def _construct_s_and_psx(
records: List[TextClassificationRecord],
Expand Down
98 changes: 70 additions & 28 deletions tests/labeling/text_classification/test_label_errors.py
Expand Up @@ -14,7 +14,9 @@
# limitations under the License.
import sys

import cleanlab
import pytest
from pkg_resources import parse_version

import rubrix as rb
from rubrix.labeling.text_classification import find_label_errors
Expand Down Expand Up @@ -77,7 +79,10 @@ def test_no_records():

def test_multi_label_warning(caplog):
record = rb.TextClassificationRecord(
text="test", prediction=[("mock", 0.0)], annotation="mock"
text="test",
prediction=[("mock", 0.0), ("mock2", 0.0)],
annotation=["mock", "mock2"],
multi_label=True,
)
find_label_errors([record], multi_label="True")
assert (
Expand All @@ -89,20 +94,32 @@ def test_multi_label_warning(caplog):
@pytest.mark.parametrize(
"sort_by,expected",
[
("likelihood", "normalized_margin"),
("prediction", "prob_given_label"),
("none", None),
("likelihood", ("normalized_margin", "normalized_margin")),
("prediction", ("prob_given_label", "self_confidence")),
("none", (None, None)),
],
)
def test_sort_by(monkeypatch, sort_by, expected):
def mock_get_noise_indices(*args, **kwargs):
assert kwargs["sorted_index_method"] == expected
return []
if parse_version(cleanlab.__version__) < parse_version("2.0"):

monkeypatch.setattr(
"cleanlab.pruning.get_noise_indices",
mock_get_noise_indices,
)
def mock_get_noise_indices(*args, **kwargs):
assert kwargs["sorted_index_method"] == expected[0]
return []

monkeypatch.setattr(
"cleanlab.pruning.get_noise_indices",
mock_get_noise_indices,
)
else:

def mock_find_label_issues(*args, **kwargs):
assert kwargs["return_indices_ranked_by"] == expected[1]
return []

monkeypatch.setattr(
"cleanlab.filter.find_label_issues",
mock_find_label_issues,
)

record = rb.TextClassificationRecord(
inputs="mock", prediction=[("mock", 0.1)], annotation="mock"
Expand All @@ -113,25 +130,50 @@ def mock_get_noise_indices(*args, **kwargs):
def test_kwargs(monkeypatch, records):
    """Check that kwargs are passed through to cleanlab, and that the
    version-specific sort kwarg is rejected when supplied by the caller."""
    is_multi_label = records[0].multi_label

    if parse_version(cleanlab.__version__) < parse_version("2.0"):
        # cleanlab v1: kwargs are forwarded to `get_noise_indices`.
        def mock_get_noise_indices(s, psx, n_jobs, **kwargs):
            assert kwargs == {
                "mock": "mock",
                "multi_label": is_multi_label,
                "sorted_index_method": "normalized_margin",
            }
            return []

        monkeypatch.setattr(
            "cleanlab.pruning.get_noise_indices",
            mock_get_noise_indices,
        )

        # callers must use 'sort_by' instead of the cleanlab-internal kwarg
        with pytest.raises(
            ValueError, match="'sorted_index_method' kwarg is not supported"
        ):
            find_label_errors(records=records, sorted_index_method="mock")

        find_label_errors(records=records, mock="mock")
    else:
        # cleanlab v2: kwargs are forwarded to `find_label_issues`.
        def mock_find_label_issues(s, psx, n_jobs, **kwargs):
            assert kwargs == {
                "mock": "mock",
                "multi_label": is_multi_label,
                # multi-label falls back to 'self_confidence', see
                # https://github.com/cleanlab/cleanlab/issues/243
                "return_indices_ranked_by": "normalized_margin"
                if not is_multi_label
                else "self_confidence",
            }
            return []

        monkeypatch.setattr(
            "cleanlab.filter.find_label_issues",
            mock_find_label_issues,
        )

        # callers must use 'sort_by' instead of the cleanlab-internal kwarg
        with pytest.raises(
            ValueError, match="'return_indices_ranked_by' kwarg is not supported"
        ):
            find_label_errors(records=records, return_indices_ranked_by="mock")

        find_label_errors(records=records, mock="mock")


def test_construct_s_and_psx(records):
Expand Down