Skip to content

Commit

Permalink
fix(cleanlab): set cleanlab n_jobs=1 as default (#1059)
Browse files Browse the repository at this point in the history
* fix(cleanlab): set cleanlab n_jobs=1 as default

* test: update tests

* Apply suggestions from code review

Co-authored-by: David Fidalgo <david@recogn.ai>

Co-authored-by: David Fidalgo <david@recogn.ai>
(cherry picked from commit 04efde8)
  • Loading branch information
frascuchon committed Jan 31, 2022
1 parent 37207bc commit 189cbcb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/rubrix/labeling/text_classification/label_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def find_label_errors(
records: List[TextClassificationRecord],
sort_by: Union[str, SortBy] = "likelihood",
metadata_key: str = "label_error_candidate",
n_jobs: Optional[int] = 1,
**kwargs,
) -> List[TextClassificationRecord]:
"""Finds potential annotation/label errors in your records using [cleanlab](https://github.com/cleanlab/cleanlab).
Expand All @@ -55,6 +56,8 @@ def find_label_errors(
- "prediction": sort the returned records by the probability of the prediction (highest probability first)
- "none": do not sort the returned records
metadata_key: The key added to the record's metadata that holds the order, if ``sort_by`` is not "none".
n_jobs: Number of processing threads used by multiprocessing. If None, uses the number of threads
on your CPU. Defaults to 1, which disables parallel processing.
**kwargs: Passed on to `cleanlab.pruning.get_noise_indices`
Returns:
Expand Down Expand Up @@ -96,7 +99,7 @@ def find_label_errors(
# construct "noisy" label vector and probability matrix of the predictions
s, psx = _construct_s_and_psx(records)

indices = get_noise_indices(s, psx, **kwargs)
indices = get_noise_indices(s, psx, n_jobs=n_jobs, **kwargs)

records_with_label_errors = np.array(records)[indices].tolist()

Expand Down
2 changes: 1 addition & 1 deletion tests/labeling/text_classification/test_label_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def mock_get_noise_indices(*args, **kwargs):
def test_kwargs(monkeypatch, records):
is_multi_label = records[0].multi_label

def mock_get_noise_indices(s, psx, **kwargs):
def mock_get_noise_indices(s, psx, n_jobs, **kwargs):
assert kwargs == {
"multi_label": is_multi_label,
"sorted_index_method": "normalized_margin",
Expand Down

0 comments on commit 189cbcb

Please sign in to comment.