diff --git a/src/rubrix/labeling/text_classification/weak_labels.py b/src/rubrix/labeling/text_classification/weak_labels.py
index b4eccb2402..6403c75af1 100644
--- a/src/rubrix/labeling/text_classification/weak_labels.py
+++ b/src/rubrix/labeling/text_classification/weak_labels.py
@@ -851,9 +851,7 @@ def summary(
         )
 
         # correct/incorrect
-        correct, incorrect = self._compute_correct_incorrect(
-            has_weak_label, annotation
-        )
+        correct, incorrect = self._compute_correct_incorrect(annotation)
 
         # precision
         precision = correct / (correct + incorrect)
@@ -881,27 +879,23 @@ def summary(
         )
 
     def _compute_correct_incorrect(
-        self, has_weak_label: np.ndarray, annotation: np.ndarray
+        self, annotation: np.ndarray
     ) -> Tuple[np.ndarray, np.ndarray]:
         """Helper method to compute the correctly and incorrectly predicted annotations by the rules"""
         # transform annotation to tensor
-        annotation = np.repeat(annotation, len(self._rules)).reshape(self._matrix.shape)
+        annotation = np.repeat(annotation, len(self._rules), axis=0).reshape(
+            self._matrix.shape
+        )
 
         # correct, we don't want to count the "correct non predictions"
-        correct_with_abstain = ((annotation == self._matrix) & (self._matrix == 1)).sum(
-            2
-        )
-        correct = np.where(has_weak_label, correct_with_abstain, False).sum(axis=0)
+        correct = ((annotation == self._matrix) & (self._matrix == 1)).sum(2).sum(0)
 
         # incorrect, we don't want to count the "misses", since we focus on precision, not recall
-        incorrect_with_abstain = (
-            (annotation != self._matrix) & (self._matrix == 1)
-        ).sum(2)
-        incorrect = np.where(
-            has_weak_label & (annotation.sum(2) >= 0),
-            incorrect_with_abstain,
-            False,
-        ).sum(axis=0)
+        incorrect = (
+            ((annotation != self._matrix) & (self._matrix == 1) & (annotation != -1))
+            .sum(2)
+            .sum(0)
+        )
 
         # add totals at the end
         return np.append(correct, correct.sum()), np.append(incorrect, incorrect.sum())
diff --git a/tests/labeling/text_classification/test_weak_labels.py b/tests/labeling/text_classification/test_weak_labels.py
index 77c4e29d50..8a77e99f89 100644
--- a/tests/labeling/text_classification/test_weak_labels.py
+++ b/tests/labeling/text_classification/test_weak_labels.py
@@ -692,6 +692,28 @@ def mock_apply(self, *args, **kwargs):
         )
         pd.testing.assert_frame_equal(summary, expected)
 
+    def test_compute_correct_incorrect(self, monkeypatch):
+        def mock_load(*args, **kwargs):
+            return [TextClassificationRecord(inputs="mock")]
+
+        monkeypatch.setattr(
+            "rubrix.labeling.text_classification.weak_labels.load", mock_load
+        )
+
+        def mock_apply(self, *args, **kwargs):
+            weak_label_matrix = np.array([[[1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.short)
+            return weak_label_matrix, None, None
+
+        monkeypatch.setattr(WeakMultiLabels, "_apply_rules", mock_apply)
+
+        weak_labels = WeakMultiLabels(rules=[lambda x: "mock"] * 2, dataset="mock")
+        correct, incorrect = weak_labels._compute_correct_incorrect(
+            annotation=np.array([[1, 0, 1, 0]])
+        )
+
+        assert np.allclose(correct, np.array([2, 0, 2]))
+        assert np.allclose(incorrect, np.array([0, 2, 2]))
+
     def test_show_records(self, monkeypatch, multilabel_rules):
         def mock_load(*args, **kwargs):
             return [
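
Reviewer note (not part of the patch): below is a minimal standalone sketch of what the refactored _compute_correct_incorrect computes, run on the same weak label tensor and annotation as the new test. The names matrix and n_rules are illustrative stand-ins for self._matrix and len(self._rules); the tensor shape is assumed to be (n_records, n_rules, n_labels), with -1 marking a missing annotation.

import numpy as np

# Weak label tensor for 1 record, 2 rules, 4 labels (values from the new test).
matrix = np.array([[[1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.short)
# One multi-label annotation per record: shape (n_records, n_labels).
annotation = np.array([[1, 0, 1, 0]])
n_rules = matrix.shape[1]

# Broadcast the annotation to the tensor shape so it can be compared per rule.
annotation = np.repeat(annotation, n_rules, axis=0).reshape(matrix.shape)

# Correct: count only actual predictions (matrix == 1), so that
# "correct non-predictions" (both sides 0) are excluded.
correct = ((annotation == matrix) & (matrix == 1)).sum(2).sum(0)

# Incorrect: wrong predictions only; the (annotation != -1) mask skips
# labels without an annotation, and misses are ignored because the
# summary reports precision, not recall.
incorrect = ((annotation != matrix) & (matrix == 1) & (annotation != -1)).sum(2).sum(0)

print(np.append(correct, correct.sum()))      # -> [2 0 2], per rule plus total
print(np.append(incorrect, incorrect.sum()))  # -> [0 2 2], per rule plus total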