Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions labelbox/schema/identifiables.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def __init__(self, iterable, id_type: IdType):
def __iter__(self):
return iter(self._iterable)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._iterable})"


class UniqueIds(Identifiables):
"""
Expand Down
35 changes: 30 additions & 5 deletions labelbox/schema/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,10 +1185,26 @@ def set_labeling_parameter_overrides(self, data) -> bool:
res = self.client.execute(query_str, {id_param: self.uid})
return res["project"]["setLabelingParameterOverrides"]["success"]

@overload
def update_data_row_labeling_priority(
self,
data_rows: DataRowIdentifiers,
priority: int,
) -> bool:
pass

@overload
def update_data_row_labeling_priority(
self,
data_rows: List[str],
priority: int,
) -> bool:
pass

def update_data_row_labeling_priority(
self,
data_rows,
priority: int,
) -> bool:
"""
Updates labeling parameter overrides to this project in bulk. This method allows up to 1 million data rows to be
Expand All @@ -1198,25 +1214,31 @@ def update_data_row_labeling_priority(
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system

Args:
data_rows (iterable): An iterable of data row ids.
data_rows: a list of data row ids to update priorities for. This can be a list of strings or a DataRowIdentifiers object
DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
priority (int): Priority for the new override. See above for more information.

Returns:
bool, indicates if the operation was a success.
"""

if isinstance(data_rows, list):
data_rows = UniqueIds(data_rows)
warnings.warn("Using data row ids will be deprecated. Please use "
"UniqueIds or GlobalKeys instead.")

method = "createQueuePriorityUpdateTask"
priority_param = "priority"
project_param = "projectId"
data_rows_param = "dataRowIds"
data_rows_param = "dataRowIdentifiers"
query_str = """mutation %sPyApi(
$%s: Int!
$%s: ID!
$%s: [ID!]
$%s: QueuePriorityUpdateDataRowIdentifiersInput
) {
project(where: { id: $%s }) {
%s(
data: { priority: $%s, dataRowIds: $%s }
data: { priority: $%s, dataRowIdentifiers: $%s }
) {
taskId
}
Expand All @@ -1228,7 +1250,10 @@ def update_data_row_labeling_priority(
query_str, {
priority_param: priority,
project_param: self.uid,
data_rows_param: data_rows
data_rows_param: {
"ids": [id for id in data_rows],
"idType": data_rows._id_type,
},
})["project"][method]

task_id = res['taskId']
Expand Down
9 changes: 7 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,18 @@ def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen,
project = consensus_project
dataset = initial_dataset

task = dataset.create_data_rows([{DataRow.row_data: image_url}] * 3)
data_rows = []
for _ in range(3):
data_rows.append({
DataRow.row_data: image_url,
DataRow.global_key: str(uuid.uuid4())
})
task = dataset.create_data_rows(data_rows)
task.wait_till_done()
assert task.status == "COMPLETE"

data_rows = list(dataset.data_rows())
assert len(data_rows) == 3

batch = project.create_batch(
rand_gen(str),
data_rows, # sample of data row objects
Expand Down
20 changes: 17 additions & 3 deletions tests/integration/test_labeling_parameter_overrides.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from labelbox import DataRow
from labelbox.schema.identifiables import GlobalKeys, UniqueIds


def test_labeling_parameter_overrides(consensus_project_with_batch):
Expand Down Expand Up @@ -49,8 +50,21 @@ def test_set_labeling_priority(consensus_project_with_batch):

data = [data_row.uid for data_row in data_rows]
success = project.update_data_row_labeling_priority(data, 1)
lo = list(project.labeling_parameter_overrides())
assert success
assert len(lo) == 3
assert {o.priority for o in lo} == {1, 1, 1}

updated_overrides = list(project.labeling_parameter_overrides())
assert len(updated_overrides) == 3
assert {o.priority for o in updated_overrides} == {1, 1, 1}
data = [data_row.uid for data_row in data_rows]
success = project.update_data_row_labeling_priority(UniqueIds(data), 2)
lo = list(project.labeling_parameter_overrides())
assert success
assert len(lo) == 3
assert {o.priority for o in lo} == {2, 2, 2}

data = [data_row.global_key for data_row in data_rows]
success = project.update_data_row_labeling_priority(GlobalKeys(data), 3)
lo = list(project.labeling_parameter_overrides())
assert success
assert len(lo) == 3
assert {o.priority for o in lo} == {3, 3, 3}
10 changes: 10 additions & 0 deletions tests/unit/test_unit_identifiables.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,13 @@ def test_global_keys():
identifiables = GlobalKeys(ids)
assert [i for i in identifiables] == ids
assert identifiables._id_type == "GKEY"


def test_repr():
ids = ["a", "b", "c"]
identifiables = GlobalKeys(ids)
assert repr(identifiables) == "GlobalKeys(['a', 'b', 'c'])"

ids = ["a", "b", "c"]
identifiables = UniqueIds(ids)
assert repr(identifiables) == "UniqueIds(['a', 'b', 'c'])"