From d3cbaea2e4be466c16545b08673ba4ed667a032b Mon Sep 17 00:00:00 2001 From: David Fidalgo Date: Mon, 17 Jan 2022 10:30:16 +0100 Subject: [PATCH] feat(#955): add default for `rules` in WeakLabels (#976) * feat: load rules of dataset by default * test: add tests * test: fix tests --- .../text_classification/weak_labels.py | 39 +++++++++++++++---- .../text_classification/test_weak_labels.py | 25 +++++++++--- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/src/rubrix/labeling/text_classification/weak_labels.py b/src/rubrix/labeling/text_classification/weak_labels.py index 3c2947c1b2..638589872d 100644 --- a/src/rubrix/labeling/text_classification/weak_labels.py +++ b/src/rubrix/labeling/text_classification/weak_labels.py @@ -21,15 +21,16 @@ from rubrix import load from rubrix.client.models import TextClassificationRecord -from rubrix.labeling.text_classification.rule import Rule +from rubrix.labeling.text_classification.rule import Rule, load_rules class WeakLabels: """Computes the weak labels of a dataset by applying a given list of rules. Args: - rules: A list of rules (labeling functions). They must return a string, or ``None`` in case of abstention. dataset: Name of the dataset to which the rules will be applied. + rules: A list of rules (labeling functions). They must return a string, or ``None`` in case of abstention. + If None, we will use the rules of the dataset (Default). ids: An optional list of record ids to filter the dataset before applying the rules. query: An optional ElasticSearch query with the `query string syntax `_ @@ -38,6 +39,7 @@ class WeakLabels: abstention (e.g. ``{None: -1}``). By default, we will build a mapping on the fly when applying the rules. Raises: + NoRulesFoundError: When you do not provide rules, and the dataset has no rules either. DuplicatedRuleNameError: When you provided multiple rules with the same name. NoRecordsFoundError: When the filtered dataset is empty. MultiLabelError: When trying to get weak labels for a multi-label text classification task. @@ -45,7 +47,13 @@ class WeakLabels: weak label or annotation label is not present in its keys. Examples: - Get the weak label matrix and a summary of the applied rules: + Get the weak label matrix from a dataset with rules: + + >>> weak_labels = WeakLabels(dataset="my_dataset") + >>> weak_labels.matrix() + >>> weak_labels.summary() + + Get the weak label matrix from rules defined in Python: >>> def awesome_rule(record: TextClassificationRecord) -> str: ... return "Positive" if "awesome" in record.inputs["text"] else None @@ -54,24 +62,37 @@ class WeakLabels: >>> weak_labels.matrix() >>> weak_labels.summary() - Use snorkel's LabelModel: + Use the WeakLabels object with snorkel's LabelModel: >>> from snorkel.labeling.model import LabelModel >>> label_model = LabelModel() >>> label_model.fit(L_train=weak_labels.matrix(has_annotation=False)) >>> label_model.score(L=weak_labels.matrix(has_annotation=True), Y=weak_labels.annotation()) >>> label_model.predict(L=weak_labels.matrix(has_annotation=False)) + + For a builtin integration with Snorkel, see `rubrix.labeling.text_classification.Snorkel`. """ def __init__( self, - rules: List[Callable], dataset: str, + rules: Optional[List[Callable]] = None, ids: Optional[List[Union[int, str]]] = None, query: Optional[str] = None, label2int: Optional[Dict[Optional[str], int]] = None, ): - self._rules = rules + if not isinstance(dataset, str): + raise TypeError( + f"The name of the dataset must be a string, but you provided: {dataset}" + ) + self._dataset = dataset + + self._rules = rules or load_rules(dataset) + if self._rules == []: + raise NoRulesFoundError( + f"No rules were found in the given dataset '{dataset}'" + ) + self._rules_index2name = { # covers our Rule class, snorkel's LabelingFunction class and arbitrary methods index: ( @@ -97,8 +118,6 @@ def __init__( val: key for key, val in self._rules_index2name.items() } - self._dataset = dataset - # load records and check compatibility self._records: List[TextClassificationRecord] = load( dataset, query=query, ids=ids, as_pandas=False @@ -499,6 +518,10 @@ class WeakLabelsError(Exception): pass +class NoRulesFoundError(WeakLabelsError): + pass + + class DuplicatedRuleNameError(WeakLabelsError): pass diff --git a/tests/labeling/text_classification/test_weak_labels.py b/tests/labeling/text_classification/test_weak_labels.py index 490d4a6861..5531e4e608 100644 --- a/tests/labeling/text_classification/test_weak_labels.py +++ b/tests/labeling/text_classification/test_weak_labels.py @@ -30,6 +30,7 @@ MissingLabelError, MultiLabelError, NoRecordsFoundError, + NoRulesFoundError, WeakLabels, ) from tests.server.test_helpers import client @@ -105,7 +106,7 @@ def mock_load(*args, **kwargs): ) with pytest.raises(MultiLabelError): - WeakLabels(rules=[], dataset="mock") + WeakLabels(rules=[lambda x: None], dataset="mock") def test_no_records_found_error(monkeypatch): @@ -119,21 +120,21 @@ def mock_load(*args, **kwargs): with pytest.raises( NoRecordsFoundError, match="No records found in dataset 'mock'." ): - WeakLabels(rules=[], dataset="mock") + WeakLabels(rules=[lambda x: None], dataset="mock") with pytest.raises( NoRecordsFoundError, match="No records found in dataset 'mock' with query 'mock'.", ): - WeakLabels(rules=[], dataset="mock", query="mock") + WeakLabels(rules=[lambda x: None], dataset="mock", query="mock") with pytest.raises( NoRecordsFoundError, match="No records found in dataset 'mock' with ids \[-1\]." ): - WeakLabels(rules=[], dataset="mock", ids=[-1]) + WeakLabels(rules=[lambda x: None], dataset="mock", ids=[-1]) with pytest.raises( NoRecordsFoundError, match="No records found in dataset 'mock' with query 'mock' and with ids \[-1\].", ): - WeakLabels(rules=[], dataset="mock", query="mock", ids=[-1]) + WeakLabels(rules=[lambda x: None], dataset="mock", query="mock", ids=[-1]) @pytest.mark.parametrize( @@ -406,3 +407,17 @@ def mock_apply(self, *args, **kwargs): weak_labels.change_mapping(old_mapping) assert (weak_labels.matrix() == old_wlm).all() + + +def test_dataset_type_error(): + with pytest.raises(TypeError, match="must be a string, but you provided"): + WeakLabels([1, 2, 3]) + + +def test_norulesfounderror(monkeypatch): + monkeypatch.setattr( + "rubrix.labeling.text_classification.weak_labels.load_rules", lambda x: [] + ) + + with pytest.raises(NoRulesFoundError, match="No rules were found"): + WeakLabels("mock")