Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions labelbox/schema/annotation_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import time
from typing import Any, BinaryIO, Dict, List, Union, TYPE_CHECKING, cast
from typing import Any, BinaryIO, Dict, List, Optional, Union, TYPE_CHECKING, cast
from collections import defaultdict

from google.api_core import retry
Expand Down Expand Up @@ -241,7 +241,47 @@ def parent_id(self) -> str:
raise NotImplementedError("Inheriting class must override")


class MEAPredictionImport(AnnotationImport):
class CreatableAnnotationImport(AnnotationImport):

@classmethod
def create(
cls,
client: "labelbox.Client",
id: str,
name: str,
path: Optional[str] = None,
url: Optional[str] = None,
labels: Union[List[Dict[str, Any]], List["Label"]] = []
) -> "AnnotationImport":
if (not is_exactly_one_set(url, labels, path)):
raise ValueError(
"Must pass in a nonempty argument for one and only one of the following arguments: url, path, predictions"
)
if url:
return cls.create_from_url(client, id, name, url)
if path:
return cls.create_from_file(client, id, name, path)
return cls.create_from_objects(client, id, name, labels)

@classmethod
def create_from_url(cls, client: "labelbox.Client", id: str, name: str,
url: str) -> "AnnotationImport":
raise NotImplementedError("Inheriting class must override")

@classmethod
def create_from_file(cls, client: "labelbox.Client", id: str, name: str,
path: str) -> "AnnotationImport":
raise NotImplementedError("Inheriting class must override")

@classmethod
def create_from_objects(
cls, client: "labelbox.Client", id: str, name: str,
labels: Union[List[Dict[str, Any]],
List["Label"]]) -> "AnnotationImport":
raise NotImplementedError("Inheriting class must override")


class MEAPredictionImport(CreatableAnnotationImport):
model_run_id = Field.String("model_run_id")

@property
Expand Down Expand Up @@ -478,7 +518,7 @@ def _get_model_run_data_rows_mutation(cls) -> str:
}""" % query.results_query_part(cls)


class MALPredictionImport(AnnotationImport):
class MALPredictionImport(CreatableAnnotationImport):
project = Relationship.ToOne("Project", cache=True)

@property
Expand Down Expand Up @@ -638,7 +678,7 @@ def _create_mal_import_from_bytes(
return cls(client, res["createModelAssistedLabelingPredictionImport"])


class LabelImport(AnnotationImport):
class LabelImport(CreatableAnnotationImport):
project = Relationship.ToOne("Project", cache=True)

@property
Expand Down
14 changes: 7 additions & 7 deletions labelbox/schema/model_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,17 @@ def add_predictions(
Returns:
AnnotationImport
"""
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
kwargs = dict(client=self.client, id=self.uid, name=name)
if isinstance(predictions, str) or isinstance(predictions, Path):
if os.path.exists(predictions):
return Entity.MEAPredictionImport.create_from_file(
path=str(predictions), **kwargs)
return Entity.MEAPredictionImport.create(path=str(predictions),
**kwargs)
else:
return Entity.MEAPredictionImport.create_from_url(
url=str(predictions), **kwargs)
return Entity.MEAPredictionImport.create(url=str(predictions),
**kwargs)
elif isinstance(predictions, Iterable):
return Entity.MEAPredictionImport.create_from_objects(
predictions=predictions, **kwargs)
return Entity.MEAPredictionImport.create(labels=predictions,
**kwargs)
else:
raise ValueError(
f'Invalid predictions given of type: {type(predictions)}')
Expand Down
4 changes: 2 additions & 2 deletions labelbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def snake_case(s):
return _convert(s, "_", lambda i: False)


def is_exactly_one_set(x, y):
return not (bool(x) == bool(y))
def is_exactly_one_set(*args):
return sum([bool(arg) for arg in args]) == 1


def is_valid_uri(uri):
Expand Down
76 changes: 64 additions & 12 deletions tests/data/annotation_import/test_label_import.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid
import pytest
from labelbox import parser

from labelbox.schema.annotation_import import AnnotationImportState, LabelImport
"""
Expand All @@ -9,6 +10,19 @@
"""


def test_create_with_url_arg(client, configured_project_with_one_data_row,
annotation_import_test_helpers):
name = str(uuid.uuid4())
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
label_import = LabelImport.create(
client=client,
id=configured_project_with_one_data_row.uid,
name=name,
url=url)
assert label_import.parent_id == configured_project_with_one_data_row.uid
annotation_import_test_helpers.check_running_state(label_import, name, url)


def test_create_from_url(client, configured_project_with_one_data_row,
annotation_import_test_helpers):
name = str(uuid.uuid4())
Expand All @@ -22,6 +36,22 @@ def test_create_from_url(client, configured_project_with_one_data_row,
annotation_import_test_helpers.check_running_state(label_import, name, url)


def test_create_with_labels_arg(client, configured_project, object_predictions,
annotation_import_test_helpers):
"""this test should check running state only to validate running, not completed"""
name = str(uuid.uuid4())

label_import = LabelImport.create(client=client,
id=configured_project.uid,
name=name,
labels=object_predictions)

assert label_import.parent_id == configured_project.uid
annotation_import_test_helpers.check_running_state(label_import, name)
annotation_import_test_helpers.assert_file_content(
label_import.input_file_url, object_predictions)


def test_create_from_objects(client, configured_project, object_predictions,
annotation_import_test_helpers):
"""this test should check running state only to validate running, not completed"""
Expand All @@ -39,20 +69,42 @@ def test_create_from_objects(client, configured_project, object_predictions,
label_import.input_file_url, object_predictions)


# TODO: add me when we add this ability
# def test_create_from_local_file(client, tmp_path, project,
# object_predictions, annotation_import_test_helpers):
# name = str(uuid.uuid4())
# file_name = f"{name}.ndjson"
# file_path = tmp_path / file_name
# with file_path.open("w") as f:
# ndjson.dump(object_predictions, f)
def test_create_with_path_arg(client, tmp_path, project, object_predictions,
annotation_import_test_helpers):
name = str(uuid.uuid4())
file_name = f"{name}.ndjson"
file_path = tmp_path / file_name
with file_path.open("w") as f:
parser.dump(object_predictions, f)

label_import = LabelImport.create(client=client,
id=project.uid,
name=name,
path=str(file_path))

assert label_import.parent_id == project.uid
annotation_import_test_helpers.check_running_state(label_import, name)
annotation_import_test_helpers.assert_file_content(
label_import.input_file_url, object_predictions)


def test_create_from_local_file(client, tmp_path, project, object_predictions,
annotation_import_test_helpers):
name = str(uuid.uuid4())
file_name = f"{name}.ndjson"
file_path = tmp_path / file_name
with file_path.open("w") as f:
parser.dump(object_predictions, f)

# label_import = LabelImport.create_from_url(client=client, project_id=project.uid, name=name, url=str(file_path))
label_import = LabelImport.create_from_url(client=client,
project_id=project.uid,
name=name,
url=str(file_path))

# assert label_import.parent_id == project.uid
# annotation_import_test_helpers.check_running_state(label_import, name)
# annotation_import_test_helpers.assert_file_content(label_import.input_file_url, object_predictions)
assert label_import.parent_id == project.uid
annotation_import_test_helpers.check_running_state(label_import, name)
annotation_import_test_helpers.assert_file_content(
label_import.input_file_url, object_predictions)


def test_get(client, configured_project_with_one_data_row,
Expand Down
58 changes: 58 additions & 0 deletions tests/data/annotation_import/test_mal_prediction_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import uuid
import pytest

from labelbox import parser
from labelbox.schema.annotation_import import MALPredictionImport
"""
- Here we only want to check that the uploads are calling the validation
- Then with unit tests we can check the types of errors raised

"""


def test_create_with_url_arg(client, configured_project_with_one_data_row,
annotation_import_test_helpers):
name = str(uuid.uuid4())
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
label_import = MALPredictionImport.create(
client=client,
id=configured_project_with_one_data_row.uid,
name=name,
url=url)
assert label_import.parent_id == configured_project_with_one_data_row.uid
annotation_import_test_helpers.check_running_state(label_import, name, url)


def test_create_with_labels_arg(client, configured_project, object_predictions,
annotation_import_test_helpers):
"""this test should check running state only to validate running, not completed"""
name = str(uuid.uuid4())

label_import = MALPredictionImport.create(client=client,
id=configured_project.uid,
name=name,
labels=object_predictions)

assert label_import.parent_id == configured_project.uid
annotation_import_test_helpers.check_running_state(label_import, name)
annotation_import_test_helpers.assert_file_content(
label_import.input_file_url, object_predictions)


def test_create_with_path_arg(client, tmp_path, project, object_predictions,
annotation_import_test_helpers):
name = str(uuid.uuid4())
file_name = f"{name}.ndjson"
file_path = tmp_path / file_name
with file_path.open("w") as f:
parser.dump(object_predictions, f)

label_import = MALPredictionImport.create(client=client,
id=project.uid,
name=name,
path=str(file_path))

assert label_import.parent_id == project.uid
annotation_import_test_helpers.check_running_state(label_import, name)
annotation_import_test_helpers.assert_file_content(
label_import.input_file_url, object_predictions)