Skip to content

Commit

Permalink
fix: computation of correct/incorrect in weak multi label summary (#1237)
Browse files Browse the repository at this point in the history

* fix: computation of correct/incorrect in weak multi label summary

* chore: simplify correct/incorrect computation

* test: add additional test
  • Loading branch information
David Fidalgo committed Mar 9, 2022
1 parent 630091f commit 2167296
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
28 changes: 11 additions & 17 deletions src/rubrix/labeling/text_classification/weak_labels.py
Expand Up @@ -851,9 +851,7 @@ def summary(
)

# correct/incorrect
correct, incorrect = self._compute_correct_incorrect(
has_weak_label, annotation
)
correct, incorrect = self._compute_correct_incorrect(annotation)

# precision
precision = correct / (correct + incorrect)
Expand Down Expand Up @@ -881,27 +879,23 @@ def summary(
)

# NOTE(review): this span is a rendered diff of one method; removed (pre-commit) and
# added (post-commit) lines appear back to back without +/- markers or indentation.
# The labels below mark which variant each run of lines appears to belong to.
def _compute_correct_incorrect(
# pre-commit signature (removed): took an extra `has_weak_label` mask
self, has_weak_label: np.ndarray, annotation: np.ndarray
# post-commit signature (added): only the annotation array is passed
self, annotation: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""Helper method to compute the correctly and incorrectly predicted annotations by the rules"""
# transform annotation to tensor
# pre-commit (removed): np.repeat without `axis` flattens the input before repeating
annotation = np.repeat(annotation, len(self._rules)).reshape(self._matrix.shape)
# post-commit (added): repeat along axis 0, once per rule, then reshape to the
# shape of the weak label matrix so both arrays can be compared element-wise
annotation = np.repeat(annotation, len(self._rules), axis=0).reshape(
self._matrix.shape
)

# correct, we don't want to count the "correct non predictions"
# pre-commit (removed): summed over axis 2, then masked with `has_weak_label`
correct_with_abstain = ((annotation == self._matrix) & (self._matrix == 1)).sum(
2
)
correct = np.where(has_weak_label, correct_with_abstain, False).sum(axis=0)
# post-commit (added): count entries where the matrix is 1 and equals the
# annotation, summing over axis 2 and then over axis 0
correct = ((annotation == self._matrix) & (self._matrix == 1)).sum(2).sum(0)

# incorrect, we don't want to count the "misses", since we focus on precision, not recall
# pre-commit (removed): masked with `has_weak_label` and a check on the annotation sum
incorrect_with_abstain = (
(annotation != self._matrix) & (self._matrix == 1)
).sum(2)
incorrect = np.where(
has_weak_label & (annotation.sum(2) >= 0),
incorrect_with_abstain,
False,
).sum(axis=0)
# post-commit (added): count entries where the matrix is 1 but differs from the
# annotation, skipping annotation entries equal to -1
# (presumably -1 marks a missing annotation -- TODO confirm)
incorrect = (
((annotation != self._matrix) & (self._matrix == 1) & (annotation != -1))
.sum(2)
.sum(0)
)

# add totals at the end
# each returned array is the summed counts followed by their overall sum
return np.append(correct, correct.sum()), np.append(incorrect, incorrect.sum())
Expand Down
22 changes: 22 additions & 0 deletions tests/labeling/text_classification/test_weak_labels.py
Expand Up @@ -692,6 +692,28 @@ def mock_apply(self, *args, **kwargs):
)
pd.testing.assert_frame_equal(summary, expected)

def test_compute_correct_incorrect(self, monkeypatch):
    """Check the correct/incorrect counts (with totals) for one record and two rules."""

    # Stub for the dataset loader: a single mock record.
    def fake_load(*args, **kwargs):
        return [TextClassificationRecord(inputs="mock")]

    # Stub for rule application: a fixed weak label matrix of shape (1, 2, 4).
    def fake_apply_rules(self, *args, **kwargs):
        matrix = np.array([[[1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.short)
        return matrix, None, None

    monkeypatch.setattr(
        "rubrix.labeling.text_classification.weak_labels.load", fake_load
    )
    monkeypatch.setattr(WeakMultiLabels, "_apply_rules", fake_apply_rules)

    rules = [lambda x: "mock"] * 2
    weak_multi_labels = WeakMultiLabels(rules=rules, dataset="mock")

    gold = np.array([[1, 0, 1, 0]])
    correct, incorrect = weak_multi_labels._compute_correct_incorrect(
        annotation=gold
    )

    # The first rule matches the annotation on both of its votes, the second on
    # none; the last element of each array is the total.
    expected_correct = np.array([2, 0, 2])
    expected_incorrect = np.array([0, 2, 2])
    assert np.allclose(correct, expected_correct)
    assert np.allclose(incorrect, expected_incorrect)

def test_show_records(self, monkeypatch, multilabel_rules):
def mock_load(*args, **kwargs):
return [
Expand Down

0 comments on commit 2167296

Please sign in to comment.