From 867f37701826cc392d37d3b051ebc80d97f2e662 Mon Sep 17 00:00:00 2001
From: David Fidalgo
Date: Fri, 28 Jan 2022 15:49:30 +0100
Subject: [PATCH] feat(#932): label models now modify the prediction_agent
 when calling LabelModel.predict (#1049)

* feat: add prediction agent

* test: add asserts
---
 src/rubrix/labeling/text_classification/label_models.py | 8 ++++++++
 tests/labeling/text_classification/test_label_models.py | 4 ++++
 2 files changed, 12 insertions(+)

diff --git a/src/rubrix/labeling/text_classification/label_models.py b/src/rubrix/labeling/text_classification/label_models.py
index d853049560..747dcafb9f 100644
--- a/src/rubrix/labeling/text_classification/label_models.py
+++ b/src/rubrix/labeling/text_classification/label_models.py
@@ -73,6 +73,7 @@ def predict(
         self,
         include_annotated_records: bool = False,
         include_abstentions: bool = False,
+        prediction_agent: str = "LabelModel",
         **kwargs,
     ) -> List[TextClassificationRecord]:
         """Applies the label model.
@@ -80,6 +81,7 @@ def predict(
         Args:
             include_annotated_records: Whether or not to include annotated records.
             include_abstentions: Whether or not to include records in the output, for which the label model abstained.
+            prediction_agent: String used for the ``prediction_agent`` in the returned records.
 
         Returns:
             A list of records that include the predictions of the label model.
@@ -282,6 +284,7 @@ def predict(
         self,
         include_annotated_records: bool = False,
         include_abstentions: bool = False,
+        prediction_agent: str = "Snorkel",
         tie_break_policy: Union[TieBreakPolicy, str] = "abstain",
     ) -> List[TextClassificationRecord]:
         """Returns a list of records that contain the predictions of the label model
@@ -289,6 +292,7 @@ def predict(
         Args:
             include_annotated_records: Whether or not to include annotated records.
             include_abstentions: Whether or not to include records in the output, for which the label model abstained.
+            prediction_agent: String used for the ``prediction_agent`` in the returned records.
             tie_break_policy: Policy to break ties. You can choose among three policies:
 
                 - `abstain`: Do not provide any prediction
@@ -355,6 +359,7 @@ def predict(
             ]
 
             records_with_prediction[-1].prediction = pred_for_rec
+            records_with_prediction[-1].prediction_agent = prediction_agent
 
         return records_with_prediction
 
@@ -531,6 +536,7 @@ def predict(
         self,
         include_annotated_records: bool = False,
         include_abstentions: bool = False,
+        prediction_agent: str = "FlyingSquid",
         verbose: bool = True,
         tie_break_policy: str = "abstain",
     ) -> List[TextClassificationRecord]:
@@ -539,6 +545,7 @@ def predict(
         Args:
             include_annotated_records: Whether or not to include annotated records.
             include_abstentions: Whether or not to include records in the output, for which the label model abstained.
+            prediction_agent: String used for the ``prediction_agent`` in the returned records.
             verbose: If True, print out messages of the progress to stderr.
             tie_break_policy: Policy to break ties.
                 You can choose among two policies:
@@ -612,6 +619,7 @@ def predict(
             records_with_prediction.append(rec.copy(deep=True))
 
             records_with_prediction[-1].prediction = pred_for_rec
+            records_with_prediction[-1].prediction_agent = prediction_agent
 
         return records_with_prediction
 
diff --git a/tests/labeling/text_classification/test_label_models.py b/tests/labeling/text_classification/test_label_models.py
index e7f11647ba..f14dd47b41 100644
--- a/tests/labeling/text_classification/test_label_models.py
+++ b/tests/labeling/text_classification/test_label_models.py
@@ -259,6 +259,7 @@ def mock_predict(self, L, return_probs, tie_break_policy, *args, **kwargs):
             tie_break_policy=policy,
             include_annotated_records=include_annotated_records,
             include_abstentions=include_abstentions,
+            prediction_agent="mock_agent",
         )
         assert len(records) == expected[0]
         assert [
@@ -267,6 +268,7 @@ def mock_predict(self, L, return_probs, tie_break_policy, *args, **kwargs):
         assert [
             rec.prediction[0][1] if rec.prediction else None for rec in records
         ] == expected[2]
+        assert records[0].prediction_agent == "mock_agent"
 
     @pytest.mark.parametrize("policy,expected", [("abstain", 0.5), ("random", 2.0 / 3)])
     def test_score(self, monkeypatch, weak_labels, policy, expected):
@@ -455,12 +457,14 @@ def __call__(cls, L_matrix, verbose):
             include_annotated_records=include_annotated_records,
             include_abstentions=include_abstentions,
             verbose=verbose,
+            prediction_agent="mock_agent",
         )
         assert MockPredict.calls_count == 3
 
         assert len(records) == expected["nr_of_records"]
         if records:
             assert records[0].prediction == expected["prediction"]
+            assert records[0].prediction_agent == "mock_agent"
 
     def test_predict_binary(self, monkeypatch, weak_labels):
         class MockPredict:
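
Usage note (not part of the patch): the sketch below illustrates how the new
``prediction_agent`` argument added by this diff could be exercised against
the Snorkel label model. Only ``predict(prediction_agent=...)``, its default
value "Snorkel", and the ``prediction_agent`` attribute on returned records
are confirmed by the diff above; the dataset name, the rule, and the agent
string are hypothetical.

    # Minimal usage sketch, assuming a running Rubrix server with a logged
    # text classification dataset named "news_sample" (hypothetical).
    from rubrix.labeling.text_classification import Snorkel, WeakLabels


    def mentions_football(record):
        # Hypothetical labeling rule: return a label string, or None to abstain.
        if "football" in record.inputs["text"].lower():
            return "SPORTS"


    # Build the weak-label matrix and fit the Snorkel label model on it.
    weak_labels = WeakLabels(rules=[mentions_football], dataset="news_sample")
    label_model = Snorkel(weak_labels)
    label_model.fit()

    # With this patch, predict() stamps every returned record with the given
    # agent string; omitting the argument falls back to "Snorkel" for this model.
    records = label_model.predict(prediction_agent="snorkel-v0.1")
    assert all(rec.prediction_agent == "snorkel-v0.1" for rec in records)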