diff --git a/labelbox/schema/annotation_import.py b/labelbox/schema/annotation_import.py index d3c209f46..68c6eb034 100644 --- a/labelbox/schema/annotation_import.py +++ b/labelbox/schema/annotation_import.py @@ -3,7 +3,7 @@ import logging import os import time -from typing import Any, BinaryIO, Dict, List, Union, TYPE_CHECKING, cast +from typing import Any, BinaryIO, Dict, List, Optional, Union, TYPE_CHECKING, cast from collections import defaultdict from google.api_core import retry @@ -241,7 +241,47 @@ def parent_id(self) -> str: raise NotImplementedError("Inheriting class must override") -class MEAPredictionImport(AnnotationImport): +class CreatableAnnotationImport(AnnotationImport): + + @classmethod + def create( + cls, + client: "labelbox.Client", + id: str, + name: str, + path: Optional[str] = None, + url: Optional[str] = None, + labels: Union[List[Dict[str, Any]], List["Label"]] = [] + ) -> "AnnotationImport": + if (not is_exactly_one_set(url, labels, path)): + raise ValueError( + "Must pass in a nonempty argument for one and only one of the following arguments: url, path, predictions" + ) + if url: + return cls.create_from_url(client, id, name, url) + if path: + return cls.create_from_file(client, id, name, path) + return cls.create_from_objects(client, id, name, labels) + + @classmethod + def create_from_url(cls, client: "labelbox.Client", id: str, name: str, + url: str) -> "AnnotationImport": + raise NotImplementedError("Inheriting class must override") + + @classmethod + def create_from_file(cls, client: "labelbox.Client", id: str, name: str, + path: str) -> "AnnotationImport": + raise NotImplementedError("Inheriting class must override") + + @classmethod + def create_from_objects( + cls, client: "labelbox.Client", id: str, name: str, + labels: Union[List[Dict[str, Any]], + List["Label"]]) -> "AnnotationImport": + raise NotImplementedError("Inheriting class must override") + + +class MEAPredictionImport(CreatableAnnotationImport): model_run_id = Field.String("model_run_id") @property @@ -478,7 +518,7 @@ def _get_model_run_data_rows_mutation(cls) -> str: }""" % query.results_query_part(cls) -class MALPredictionImport(AnnotationImport): +class MALPredictionImport(CreatableAnnotationImport): project = Relationship.ToOne("Project", cache=True) @property @@ -638,7 +678,7 @@ def _create_mal_import_from_bytes( return cls(client, res["createModelAssistedLabelingPredictionImport"]) -class LabelImport(AnnotationImport): +class LabelImport(CreatableAnnotationImport): project = Relationship.ToOne("Project", cache=True) @property diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index af3517dde..70cc49e39 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -286,17 +286,17 @@ def add_predictions( Returns: AnnotationImport """ - kwargs = dict(client=self.client, model_run_id=self.uid, name=name) + kwargs = dict(client=self.client, id=self.uid, name=name) if isinstance(predictions, str) or isinstance(predictions, Path): if os.path.exists(predictions): - return Entity.MEAPredictionImport.create_from_file( - path=str(predictions), **kwargs) + return Entity.MEAPredictionImport.create(path=str(predictions), + **kwargs) else: - return Entity.MEAPredictionImport.create_from_url( - url=str(predictions), **kwargs) + return Entity.MEAPredictionImport.create(url=str(predictions), + **kwargs) elif isinstance(predictions, Iterable): - return Entity.MEAPredictionImport.create_from_objects( - predictions=predictions, **kwargs) + return Entity.MEAPredictionImport.create(labels=predictions, + **kwargs) else: raise ValueError( f'Invalid predictions given of type: {type(predictions)}') diff --git a/labelbox/utils.py b/labelbox/utils.py index da2dbdec4..f606932c7 100644 --- a/labelbox/utils.py +++ b/labelbox/utils.py @@ -39,8 +39,8 @@ def snake_case(s): return _convert(s, "_", lambda i: False) -def is_exactly_one_set(x, y): - return not (bool(x) == bool(y)) +def is_exactly_one_set(*args): + return sum([bool(arg) for arg in args]) == 1 def is_valid_uri(uri): diff --git a/tests/data/annotation_import/test_label_import.py b/tests/data/annotation_import/test_label_import.py index 61c602c52..ddc63d3e0 100644 --- a/tests/data/annotation_import/test_label_import.py +++ b/tests/data/annotation_import/test_label_import.py @@ -1,5 +1,6 @@ import uuid import pytest +from labelbox import parser from labelbox.schema.annotation_import import AnnotationImportState, LabelImport """ @@ -9,6 +10,19 @@ """ +def test_create_with_url_arg(client, configured_project_with_one_data_row, + annotation_import_test_helpers): + name = str(uuid.uuid4()) + url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" + label_import = LabelImport.create( + client=client, + id=configured_project_with_one_data_row.uid, + name=name, + url=url) + assert label_import.parent_id == configured_project_with_one_data_row.uid + annotation_import_test_helpers.check_running_state(label_import, name, url) + + def test_create_from_url(client, configured_project_with_one_data_row, annotation_import_test_helpers): name = str(uuid.uuid4()) @@ -22,6 +36,22 @@ def test_create_from_url(client, configured_project_with_one_data_row, annotation_import_test_helpers.check_running_state(label_import, name, url) +def test_create_with_labels_arg(client, configured_project, object_predictions, + annotation_import_test_helpers): + """this test should check running state only to validate running, not completed""" + name = str(uuid.uuid4()) + + label_import = LabelImport.create(client=client, + id=configured_project.uid, + name=name, + labels=object_predictions) + + assert label_import.parent_id == configured_project.uid + annotation_import_test_helpers.check_running_state(label_import, name) + annotation_import_test_helpers.assert_file_content( + label_import.input_file_url, object_predictions) + + def test_create_from_objects(client, configured_project, object_predictions, annotation_import_test_helpers): """this test should check running state only to validate running, not completed""" @@ -39,20 +69,42 @@ def test_create_from_objects(client, configured_project, object_predictions, label_import.input_file_url, object_predictions) -# TODO: add me when we add this ability -# def test_create_from_local_file(client, tmp_path, project, -# object_predictions, annotation_import_test_helpers): -# name = str(uuid.uuid4()) -# file_name = f"{name}.ndjson" -# file_path = tmp_path / file_name -# with file_path.open("w") as f: -# ndjson.dump(object_predictions, f) +def test_create_with_path_arg(client, tmp_path, project, object_predictions, + annotation_import_test_helpers): + name = str(uuid.uuid4()) + file_name = f"{name}.ndjson" + file_path = tmp_path / file_name + with file_path.open("w") as f: + parser.dump(object_predictions, f) + + label_import = LabelImport.create(client=client, + id=project.uid, + name=name, + path=str(file_path)) + + assert label_import.parent_id == project.uid + annotation_import_test_helpers.check_running_state(label_import, name) + annotation_import_test_helpers.assert_file_content( + label_import.input_file_url, object_predictions) + + +def test_create_from_local_file(client, tmp_path, project, object_predictions, + annotation_import_test_helpers): + name = str(uuid.uuid4()) + file_name = f"{name}.ndjson" + file_path = tmp_path / file_name + with file_path.open("w") as f: + parser.dump(object_predictions, f) -# label_import = LabelImport.create_from_url(client=client, project_id=project.uid, name=name, url=str(file_path)) + label_import = LabelImport.create_from_url(client=client, + project_id=project.uid, + name=name, + url=str(file_path)) -# assert label_import.parent_id == project.uid -# annotation_import_test_helpers.check_running_state(label_import, name) -# annotation_import_test_helpers.assert_file_content(label_import.input_file_url, object_predictions) + assert label_import.parent_id == project.uid + annotation_import_test_helpers.check_running_state(label_import, name) + annotation_import_test_helpers.assert_file_content( + label_import.input_file_url, object_predictions) def test_get(client, configured_project_with_one_data_row, diff --git a/tests/data/annotation_import/test_mal_prediction_import.py b/tests/data/annotation_import/test_mal_prediction_import.py new file mode 100644 index 000000000..c261f7065 --- /dev/null +++ b/tests/data/annotation_import/test_mal_prediction_import.py @@ -0,0 +1,58 @@ +import uuid +import pytest + +from labelbox import parser +from labelbox.schema.annotation_import import MALPredictionImport +""" +- Here we only want to check that the uploads are calling the validation +- Then with unit tests we can check the types of errors raised + +""" + + +def test_create_with_url_arg(client, configured_project_with_one_data_row, + annotation_import_test_helpers): + name = str(uuid.uuid4()) + url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" + label_import = MALPredictionImport.create( + client=client, + id=configured_project_with_one_data_row.uid, + name=name, + url=url) + assert label_import.parent_id == configured_project_with_one_data_row.uid + annotation_import_test_helpers.check_running_state(label_import, name, url) + + +def test_create_with_labels_arg(client, configured_project, object_predictions, + annotation_import_test_helpers): + """this test should check running state only to validate running, not completed""" + name = str(uuid.uuid4()) + + label_import = MALPredictionImport.create(client=client, + id=configured_project.uid, + name=name, + labels=object_predictions) + + assert label_import.parent_id == configured_project.uid + annotation_import_test_helpers.check_running_state(label_import, name) + annotation_import_test_helpers.assert_file_content( + label_import.input_file_url, object_predictions) + + +def test_create_with_path_arg(client, tmp_path, project, object_predictions, + annotation_import_test_helpers): + name = str(uuid.uuid4()) + file_name = f"{name}.ndjson" + file_path = tmp_path / file_name + with file_path.open("w") as f: + parser.dump(object_predictions, f) + + label_import = MALPredictionImport.create(client=client, + id=project.uid, + name=name, + path=str(file_path)) + + assert label_import.parent_id == project.uid + annotation_import_test_helpers.check_running_state(label_import, name) + annotation_import_test_helpers.assert_file_content( + label_import.input_file_url, object_predictions)