diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py index 6a2b5066e..f3bc339c0 100644 --- a/labelbox/schema/project.py +++ b/labelbox/schema/project.py @@ -1166,6 +1166,60 @@ def set_labeling_parameter_overrides(self, data) -> bool: res = self.client.execute(query_str, {id_param: self.uid}) return res["project"]["setLabelingParameterOverrides"]["success"] + def update_data_row_labeling_priority( + self, + data_rows: List[str], + priority: int, + ) -> bool: + """ + Updates labeling parameter overrides to this project in bulk. This method allows up to 1 million data rows to be + updated at once. + + See information on priority here: + https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system + + Args: + data_rows (iterable): An iterable of data row ids. + priority (int): Priority for the new override. See above for more information. + + Returns: + bool, indicates if the operation was a success. + """ + + method = "createQueuePriorityUpdateTask" + priority_param = "priority" + project_param = "projectId" + data_rows_param = "dataRowIds" + query_str = """mutation %sPyApi( + $%s: Int! + $%s: ID! + $%s: [ID!] + ) { + project(where: { id: $%s }) { + %s( + data: { priority: $%s, dataRowIds: $%s } + ) { + taskId + } + } + } + """ % (method, priority_param, project_param, data_rows_param, + project_param, method, priority_param, data_rows_param) + res = self.client.execute( + query_str, { + priority_param: priority, + project_param: self.uid, + data_rows_param: data_rows + })["project"][method] + + task_id = res['taskId'] + + task = self._wait_for_task(task_id) + if task.status != "COMPLETE": + raise LabelboxError(f"Priority was not updated successfully: " + + json.dumps(task.errors)) + return True + def upsert_review_queue(self, quota_factor) -> None: """ Sets the the proportion of total assets in a project to review. diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 9fa22d014..909f1446d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -11,7 +11,7 @@ import pytest import requests -from labelbox import Dataset +from labelbox import Dataset, DataRow from labelbox import LabelingFrontend from labelbox import OntologyBuilder, Tool, Option, Classification, MediaType from labelbox.orm import query @@ -167,6 +167,29 @@ def consensus_project(client, rand_gen): project.delete() +@pytest.fixture +def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen, + image_url): + project = consensus_project + dataset = initial_dataset + + task = dataset.create_data_rows([{DataRow.row_data: image_url}] * 3) + task.wait_till_done() + assert task.status == "COMPLETE" + + data_rows = list(dataset.data_rows()) + assert len(data_rows) == 3 + + batch = project.create_batch( + rand_gen(str), + data_rows, # sample of data row objects + 5 # priority between 1(Highest) - 5(lowest) + ) + + yield [project, batch, data_rows] + batch.delete() + + @pytest.fixture def dataset(client, rand_gen): dataset = client.create_dataset(name=rand_gen(str)) diff --git a/tests/integration/test_labeling_parameter_overrides.py b/tests/integration/test_labeling_parameter_overrides.py index 25b03b48f..46b46b2b1 100644 --- a/tests/integration/test_labeling_parameter_overrides.py +++ b/tests/integration/test_labeling_parameter_overrides.py @@ -2,23 +2,8 @@ from labelbox import DataRow -def test_labeling_parameter_overrides(consensus_project, initial_dataset, - rand_gen, image_url): - project = consensus_project - dataset = initial_dataset - - task = dataset.create_data_rows([{DataRow.row_data: image_url}] * 3) - task.wait_till_done() - assert task.status == "COMPLETE" - - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 3 - - project.create_batch( - rand_gen(str), - data_rows, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) +def test_labeling_parameter_overrides(consensus_project_with_batch): + [project, _, data_rows] = consensus_project_with_batch init_labeling_parameter_overrides = list( project.labeling_parameter_overrides()) @@ -45,10 +30,27 @@ def test_labeling_parameter_overrides(consensus_project, initial_dataset, data = [(data_rows[2], "a_string", 3)] project.set_labeling_parameter_overrides(data) assert str(exc_info.value) == \ - f"Priority must be an int. Found for data_row {data_rows[2]}. Index: 0" + f"Priority must be an int. Found for data_row {data_rows[2]}. Index: 0" with pytest.raises(TypeError) as exc_info: data = [(data_rows[2].uid, 1)] project.set_labeling_parameter_overrides(data) assert str(exc_info.value) == \ - "data_row should be be of type DataRow. Found . Index: 0" + "data_row should be be of type DataRow. Found . Index: 0" + + +def test_set_labeling_priority(consensus_project_with_batch): + [project, _, data_rows] = consensus_project_with_batch + + init_labeling_parameter_overrides = list( + project.labeling_parameter_overrides()) + assert len(init_labeling_parameter_overrides) == 3 + assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} + + data = [data_row.uid for data_row in data_rows] + success = project.update_data_row_labeling_priority(data, 1) + assert success + + updated_overrides = list(project.labeling_parameter_overrides()) + assert len(updated_overrides) == 3 + assert {o.priority for o in updated_overrides} == {1, 1, 1}