From d7b8aa15c21f1ffdce1032e474f0c38124ba9a45 Mon Sep 17 00:00:00 2001 From: Samuel Fendell Date: Fri, 5 Apr 2024 13:43:16 -0700 Subject: [PATCH] Revert "[PLT-150] Add unified create method for AnnotationImport, MEAPredictionImport, and MALPredictionImport. (#1523)" This reverts commit a07d582bbc3a91fe6c5ae642ef0675a74d7fecd4. --- labelbox/schema/annotation_import.py | 48 +----------- labelbox/schema/model_run.py | 14 ++-- labelbox/utils.py | 4 +- .../annotation_import/test_label_import.py | 76 +++---------------- .../test_mal_prediction_import.py | 58 -------------- 5 files changed, 25 insertions(+), 175 deletions(-) delete mode 100644 tests/data/annotation_import/test_mal_prediction_import.py diff --git a/labelbox/schema/annotation_import.py b/labelbox/schema/annotation_import.py index 68c6eb034..d3c209f46 100644 --- a/labelbox/schema/annotation_import.py +++ b/labelbox/schema/annotation_import.py @@ -3,7 +3,7 @@ import logging import os import time -from typing import Any, BinaryIO, Dict, List, Optional, Union, TYPE_CHECKING, cast +from typing import Any, BinaryIO, Dict, List, Union, TYPE_CHECKING, cast from collections import defaultdict from google.api_core import retry @@ -241,47 +241,7 @@ def parent_id(self) -> str: raise NotImplementedError("Inheriting class must override") -class CreatableAnnotationImport(AnnotationImport): - - @classmethod - def create( - cls, - client: "labelbox.Client", - id: str, - name: str, - path: Optional[str] = None, - url: Optional[str] = None, - labels: Union[List[Dict[str, Any]], List["Label"]] = [] - ) -> "AnnotationImport": - if (not is_exactly_one_set(url, labels, path)): - raise ValueError( - "Must pass in a nonempty argument for one and only one of the following arguments: url, path, predictions" - ) - if url: - return cls.create_from_url(client, id, name, url) - if path: - return cls.create_from_file(client, id, name, path) - return cls.create_from_objects(client, id, name, labels) - - @classmethod - def create_from_url(cls, client: "labelbox.Client", id: str, name: str, - url: str) -> "AnnotationImport": - raise NotImplementedError("Inheriting class must override") - - @classmethod - def create_from_file(cls, client: "labelbox.Client", id: str, name: str, - path: str) -> "AnnotationImport": - raise NotImplementedError("Inheriting class must override") - - @classmethod - def create_from_objects( - cls, client: "labelbox.Client", id: str, name: str, - labels: Union[List[Dict[str, Any]], - List["Label"]]) -> "AnnotationImport": - raise NotImplementedError("Inheriting class must override") - - -class MEAPredictionImport(CreatableAnnotationImport): +class MEAPredictionImport(AnnotationImport): model_run_id = Field.String("model_run_id") @property @@ -518,7 +478,7 @@ def _get_model_run_data_rows_mutation(cls) -> str: }""" % query.results_query_part(cls) -class MALPredictionImport(CreatableAnnotationImport): +class MALPredictionImport(AnnotationImport): project = Relationship.ToOne("Project", cache=True) @property @@ -678,7 +638,7 @@ def _create_mal_import_from_bytes( return cls(client, res["createModelAssistedLabelingPredictionImport"]) -class LabelImport(CreatableAnnotationImport): +class LabelImport(AnnotationImport): project = Relationship.ToOne("Project", cache=True) @property diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 70cc49e39..af3517dde 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -286,17 +286,17 @@ def add_predictions( Returns: AnnotationImport """ - kwargs = dict(client=self.client, id=self.uid, name=name) + kwargs = dict(client=self.client, model_run_id=self.uid, name=name) if isinstance(predictions, str) or isinstance(predictions, Path): if os.path.exists(predictions): - return Entity.MEAPredictionImport.create(path=str(predictions), - **kwargs) + return Entity.MEAPredictionImport.create_from_file( + path=str(predictions), **kwargs) else: - return Entity.MEAPredictionImport.create(url=str(predictions), - **kwargs) + return Entity.MEAPredictionImport.create_from_url( + url=str(predictions), **kwargs) elif isinstance(predictions, Iterable): - return Entity.MEAPredictionImport.create(labels=predictions, - **kwargs) + return Entity.MEAPredictionImport.create_from_objects( + predictions=predictions, **kwargs) else: raise ValueError( f'Invalid predictions given of type: {type(predictions)}') diff --git a/labelbox/utils.py b/labelbox/utils.py index f606932c7..da2dbdec4 100644 --- a/labelbox/utils.py +++ b/labelbox/utils.py @@ -39,8 +39,8 @@ def snake_case(s): return _convert(s, "_", lambda i: False) -def is_exactly_one_set(*args): - return sum([bool(arg) for arg in args]) == 1 +def is_exactly_one_set(x, y): + return not (bool(x) == bool(y)) def is_valid_uri(uri): diff --git a/tests/data/annotation_import/test_label_import.py b/tests/data/annotation_import/test_label_import.py index ddc63d3e0..61c602c52 100644 --- a/tests/data/annotation_import/test_label_import.py +++ b/tests/data/annotation_import/test_label_import.py @@ -1,6 +1,5 @@ import uuid import pytest -from labelbox import parser from labelbox.schema.annotation_import import AnnotationImportState, LabelImport """ @@ -10,19 +9,6 @@ """ -def test_create_with_url_arg(client, configured_project_with_one_data_row, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - label_import = LabelImport.create( - client=client, - id=configured_project_with_one_data_row.uid, - name=name, - url=url) - assert label_import.parent_id == configured_project_with_one_data_row.uid - annotation_import_test_helpers.check_running_state(label_import, name, url) - - def test_create_from_url(client, configured_project_with_one_data_row, annotation_import_test_helpers): name = str(uuid.uuid4()) @@ -36,22 +22,6 @@ def test_create_from_url(client, configured_project_with_one_data_row, annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_with_labels_arg(client, configured_project, object_predictions, - annotation_import_test_helpers): - """this test should check running state only to validate running, not completed""" - name = str(uuid.uuid4()) - - label_import = LabelImport.create(client=client, - id=configured_project.uid, - name=name, - labels=object_predictions) - - assert label_import.parent_id == configured_project.uid - annotation_import_test_helpers.check_running_state(label_import, name) - annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) - - def test_create_from_objects(client, configured_project, object_predictions, annotation_import_test_helpers): """this test should check running state only to validate running, not completed""" @@ -69,42 +39,20 @@ def test_create_from_objects(client, configured_project, object_predictions, label_import.input_file_url, object_predictions) -def test_create_with_path_arg(client, tmp_path, project, object_predictions, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - file_name = f"{name}.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - parser.dump(object_predictions, f) - - label_import = LabelImport.create(client=client, - id=project.uid, - name=name, - path=str(file_path)) - - assert label_import.parent_id == project.uid - annotation_import_test_helpers.check_running_state(label_import, name) - annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) - - -def test_create_from_local_file(client, tmp_path, project, object_predictions, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - file_name = f"{name}.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - parser.dump(object_predictions, f) +# TODO: add me when we add this ability +# def test_create_from_local_file(client, tmp_path, project, +# object_predictions, annotation_import_test_helpers): +# name = str(uuid.uuid4()) +# file_name = f"{name}.ndjson" +# file_path = tmp_path / file_name +# with file_path.open("w") as f: +# ndjson.dump(object_predictions, f) - label_import = LabelImport.create_from_url(client=client, - project_id=project.uid, - name=name, - url=str(file_path)) +# label_import = LabelImport.create_from_url(client=client, project_id=project.uid, name=name, url=str(file_path)) - assert label_import.parent_id == project.uid - annotation_import_test_helpers.check_running_state(label_import, name) - annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) +# assert label_import.parent_id == project.uid +# annotation_import_test_helpers.check_running_state(label_import, name) +# annotation_import_test_helpers.assert_file_content(label_import.input_file_url, object_predictions) def test_get(client, configured_project_with_one_data_row, diff --git a/tests/data/annotation_import/test_mal_prediction_import.py b/tests/data/annotation_import/test_mal_prediction_import.py deleted file mode 100644 index c261f7065..000000000 --- a/tests/data/annotation_import/test_mal_prediction_import.py +++ /dev/null @@ -1,58 +0,0 @@ -import uuid -import pytest - -from labelbox import parser -from labelbox.schema.annotation_import import MALPredictionImport -""" -- Here we only want to check that the uploads are calling the validation -- Then with unit tests we can check the types of errors raised - -""" - - -def test_create_with_url_arg(client, configured_project_with_one_data_row, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - label_import = MALPredictionImport.create( - client=client, - id=configured_project_with_one_data_row.uid, - name=name, - url=url) - assert label_import.parent_id == configured_project_with_one_data_row.uid - annotation_import_test_helpers.check_running_state(label_import, name, url) - - -def test_create_with_labels_arg(client, configured_project, object_predictions, - annotation_import_test_helpers): - """this test should check running state only to validate running, not completed""" - name = str(uuid.uuid4()) - - label_import = MALPredictionImport.create(client=client, - id=configured_project.uid, - name=name, - labels=object_predictions) - - assert label_import.parent_id == configured_project.uid - annotation_import_test_helpers.check_running_state(label_import, name) - annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) - - -def test_create_with_path_arg(client, tmp_path, project, object_predictions, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - file_name = f"{name}.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - parser.dump(object_predictions, f) - - label_import = MALPredictionImport.create(client=client, - id=project.uid, - name=name, - path=str(file_path)) - - assert label_import.parent_id == project.uid - annotation_import_test_helpers.check_running_state(label_import, name) - annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions)