Merged
54 changes: 49 additions & 5 deletions labelbox/schema/model_run.py
@@ -47,22 +47,43 @@ class Status(Enum):
COMPLETE = "COMPLETE"
FAILED = "FAILED"

def upsert_labels(self, label_ids, timeout_seconds=3600):
def upsert_labels(self,
label_ids: Optional[List[str]] = None,
project_id: Optional[str] = None,
timeout_seconds=3600):
""" Adds data rows and labels to a Model Run
Args:
label_ids (list): label ids to insert
                project_id (string): project uuid; all labels in the project will be uploaded
                Either label_ids OR project_id is required, but NOT both
timeout_seconds (float): Max waiting time, in seconds.
Returns:
ID of newly generated async task

"""

if len(label_ids) < 1:
raise ValueError("Must provide at least one label id")
use_label_ids = label_ids is not None and len(label_ids) > 0
use_project_id = project_id is not None

if not use_label_ids and not use_project_id:
raise ValueError(
"Must provide at least one label id or a project id")

if use_label_ids and use_project_id:
raise ValueError("Must only one of label ids, project id")

if use_label_ids:
return self._upsert_labels_by_label_ids(label_ids, timeout_seconds)
else: # use_project_id
return self._upsert_labels_by_project_id(project_id,
timeout_seconds)

def _upsert_labels_by_label_ids(self, label_ids: List[str],
timeout_seconds: int):
mutation_name = 'createMEAModelRunLabelRegistrationTask'
create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
""" % (mutation_name)
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
""" % (mutation_name)

res = self.client.execute(create_task_query_str, {
'modelRunId': self.uid,
@@ -80,6 +101,29 @@ def upsert_labels(self, label_ids, timeout_seconds=3600):
}})['MEALabelRegistrationTaskStatus'],
timeout_seconds=timeout_seconds)

def _upsert_labels_by_project_id(self, project_id: str,
timeout_seconds: int):
mutation_name = 'createMEAModelRunProjectLabelRegistrationTask'
create_task_query_str = """mutation createMEAModelRunProjectLabelRegistrationTaskPyApi($modelRunId: ID!, $projectId : ID!) {
%s(where : { modelRunId : $modelRunId, projectId: $projectId})}
""" % (mutation_name)

res = self.client.execute(create_task_query_str, {
'modelRunId': self.uid,
'projectId': project_id
})
task_id = res[mutation_name]

status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
}
"""
return self._wait_until_done(lambda: self.client.execute(
status_query_str, {'where': {
'id': task_id
}})['MEALabelRegistrationTaskStatus'],
timeout_seconds=timeout_seconds)

def upsert_data_rows(self,
data_row_ids=None,
global_keys=None,
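
For reference, a minimal usage sketch of the updated upsert_labels API — this is not part of the diff, and the model_run instance and ids below are hypothetical placeholders:

# Hypothetical sketch: assumes an existing ModelRun instance `model_run`
# (from labelbox.schema.model_run); the ids are placeholders.

# Register specific labels by id (pre-existing behaviour):
model_run.upsert_labels(label_ids=["<label-id-1>", "<label-id-2>"])

# Or register every label in a project (added by this change):
model_run.upsert_labels(project_id="<project-id>")

# Supplying both arguments, or neither, raises a ValueError.
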
16 changes: 16 additions & 0 deletions tests/integration/annotation_import/conftest.py
@@ -440,6 +440,22 @@ def model_run_with_model_run_data_rows(client, configured_project,
# TODO: Delete resources when that is possible ..


@pytest.fixture
def model_run_with_all_project_labels(client, configured_project,
model_run_predictions, model_run):
configured_project.enable_model_assisted_labeling()

upload_task = LabelImport.create_from_objects(
client, configured_project.uid, f"label-import-{uuid.uuid4()}",
model_run_predictions)
upload_task.wait_until_done()
model_run.upsert_labels(project_id=configured_project.uid)
time.sleep(3)
yield model_run
model_run.delete()
# TODO: Delete resources when that is possible ..


class AnnotationImportTestHelpers:

@classmethod
@@ -38,6 +38,44 @@ def test_create_from_objects(model_run_with_model_run_data_rows,
annotation_import.wait_until_done()


def test_create_from_objects_all_project_labels(
model_run_with_all_project_labels, object_predictions,
annotation_import_test_helpers):
name = str(uuid.uuid4())

annotation_import = model_run_with_all_project_labels.add_predictions(
name=name, predictions=object_predictions)

assert annotation_import.model_run_id == model_run_with_all_project_labels.uid
annotation_import_test_helpers.check_running_state(annotation_import, name)
annotation_import_test_helpers.assert_file_content(
annotation_import.input_file_url, object_predictions)
annotation_import.wait_until_done()


def test_model_run_project_labels(model_run_with_all_project_labels,
model_run_predictions):
model_run = model_run_with_all_project_labels
model_run_exported_labels = model_run.export_labels(download=True)
labels_indexed_by_schema_id = {}
for label in model_run_exported_labels:
        # assuming the exported array of label 'objects' has only one label per data row, as is usually the case when there are no label revisions
schema_id = label['Label']['objects'][0]['schemaId']
labels_indexed_by_schema_id[schema_id] = label

assert (len(
labels_indexed_by_schema_id.keys())) == len(model_run_predictions)

    # make sure the labels in this model run are all labels uploaded to the project
# by comparing some 'immutable' attributes
for expected_label in model_run_predictions:
schema_id = expected_label['schemaId']
actual_label = labels_indexed_by_schema_id[schema_id]
assert actual_label['Label']['objects'][0]['title'] == expected_label[
'name']
assert actual_label['DataRow ID'] == expected_label['dataRow']['id']


def test_create_from_label_objects(model_run_with_model_run_data_rows,
object_predictions,
annotation_import_test_helpers):