diff --git a/labelbox/__init__.py b/labelbox/__init__.py index a7263200e..23eb7d752 100644 --- a/labelbox/__init__.py +++ b/labelbox/__init__.py @@ -34,3 +34,4 @@ from labelbox.schema.queue_mode import QueueMode from labelbox.schema.task_queue import TaskQueue from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds +from labelbox.schema.identifiable import UniqueId, GlobalKey diff --git a/labelbox/schema/__init__.py b/labelbox/schema/__init__.py index d3bb1985e..dc9df500d 100644 --- a/labelbox/schema/__init__.py +++ b/labelbox/schema/__init__.py @@ -22,3 +22,4 @@ import labelbox.schema.iam_integration import labelbox.schema.media_type import labelbox.schema.identifiables +import labelbox.schema.identifiable diff --git a/labelbox/schema/id_type.py b/labelbox/schema/id_type.py new file mode 100644 index 000000000..8bf00b300 --- /dev/null +++ b/labelbox/schema/id_type.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class IdType(str, Enum): + """ + The type of id used to identify a data row. + + Currently supported types are: + - DataRowId: The id assigned to a data row by Labelbox. + - GlobalKey: The id assigned to a data row by the user. + """ + DataRowId = "ID" + GlobalKey = "GKEY" diff --git a/labelbox/schema/identifiable.py b/labelbox/schema/identifiable.py index 33d4e0d11..90e0716ce 100644 --- a/labelbox/schema/identifiable.py +++ b/labelbox/schema/identifiable.py @@ -1,5 +1,7 @@ -from abc import ABC, abstractmethod -from typing import List, Union +from abc import ABC +from typing import Union + +from labelbox.schema.id_type import IdType class Identifiable(ABC): @@ -7,32 +9,44 @@ class Identifiable(ABC): Base class for any object representing a unique identifier. 
""" - def __init__(self, key: str): + def __init__(self, key: str, id_type: IdType): self._key = key + self._id_type = id_type @property def key(self): - return self.key + return self._key + + @property + def id_type(self): + return self._id_type def __eq__(self, other): - return other.key == self.key + return other.key == self.key and other.id_type == self.id_type def __hash__(self): - hash(self.key) + return hash((self.key, self.id_type)) def __str__(self): - return self.key.__str__() + return f"{self.id_type}:{self.key}" class UniqueId(Identifiable): """ Represents a unique, internally generated id. """ - pass + + def __init__(self, key: str): + super().__init__(key, IdType.DataRowId) class GlobalKey(Identifiable): """ Represents a user generated id. """ - pass + + def __init__(self, key: str): + super().__init__(key, IdType.GlobalKey) + + +DataRowIdentifier = Union[UniqueId, GlobalKey] diff --git a/labelbox/schema/identifiables.py b/labelbox/schema/identifiables.py index ba4265b17..a77704236 100644 --- a/labelbox/schema/identifiables.py +++ b/labelbox/schema/identifiables.py @@ -1,17 +1,6 @@ -from enum import Enum from typing import List, Union - -class IdType(str, Enum): - """ - The type of id used to identify a data row. - - Currently supported types are: - - DataRowId: The id assigned to a data row by Labelbox. - - GlobalKey: The id assigned to a data row by the user. 
- """ - DataRowId = "ID" - GlobalKey = "GKEY" +from labelbox.schema.id_type import IdType class Identifiables: diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py index 3bcd6333c..73680076b 100644 --- a/labelbox/schema/project.py +++ b/labelbox/schema/project.py @@ -1,11 +1,12 @@ import json import logging +from string import Template import time import warnings from collections import namedtuple from datetime import datetime, timezone from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, overload from urllib.parse import urlparse import requests @@ -28,6 +29,8 @@ from labelbox.schema.export_filters import ProjectExportFilters, validate_datetime, build_filters from labelbox.schema.export_params import ProjectExportParams from labelbox.schema.export_task import ExportTask +from labelbox.schema.id_type import IdType +from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds from labelbox.schema.media_type import MediaType from labelbox.schema.queue_mode import QueueMode @@ -43,9 +46,38 @@ except ImportError: pass +DataRowPriority = int +LabelingParameterOverrideInput = Tuple[Union[DataRow, DataRowIdentifier], + DataRowPriority] + logger = logging.getLogger(__name__) +def validate_labeling_parameter_overrides( + data: List[LabelingParameterOverrideInput]) -> None: + for idx, row in enumerate(data): + if len(row) < 2: + raise TypeError( + f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. 
Index: {idx}" + ) + data_row_identifier = row[0] + priority = row[1] + valid_types = (Entity.DataRow, UniqueId, GlobalKey) + if not isinstance(data_row_identifier, valid_types): + raise TypeError( + f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}" + ) + + if not isinstance(priority, int): + if isinstance(data_row_identifier, Entity.DataRow): + id = data_row_identifier.uid + else: + id = data_row_identifier + raise TypeError( + f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}" + ) + + class Project(DbObject, Updateable, Deletable): """ A Project is a container that includes a labeling frontend, an ontology, datasets and labels. @@ -1129,36 +1161,25 @@ def get_queue_mode(self) -> "QueueMode": else: raise ValueError("Status not known") - def validate_labeling_parameter_overrides(self, data) -> None: - for idx, row in enumerate(data): - if len(row) < 2: - raise TypeError( - f"Data must be a list of tuples containing a DataRow and priority (int). Found {len(row)} items. Index: {idx}" - ) - data_row = row[0] - priority = row[1] - if not isinstance(data_row, Entity.DataRow): - raise TypeError( - f"data_row should be be of type DataRow. Found {type(data_row)}. Index: {idx}" - ) - - if not isinstance(priority, int): - raise TypeError( - f"Priority must be an int. Found {type(priority)} for data_row {data_row}. Index: {idx}" - ) - - def set_labeling_parameter_overrides(self, data) -> bool: + def set_labeling_parameter_overrides( + self, data: List[LabelingParameterOverrideInput]) -> bool: """ Adds labeling parameter overrides to this project. 
See information on priority here: https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system >>> project.set_labeling_parameter_overrides([ - >>> (data_row_1, 2), (data_row_2, 1)]) + >>> (data_row_id1, 2), (data_row_id2, 1)]) + or + >>> project.set_labeling_parameter_overrides([ + >>> (data_row_gk1, 2), (data_row_gk2, 1)]) Args: data (iterable): An iterable of tuples. Each tuple must contain - (DataRow, priority) for the new override. + either (DataRow, DataRowPriority) + or (DataRowIdentifier, priority) for the new override. + DataRowIdentifier is an object representing a data row id or a global key. A DataRowIdentifier object can be a UniqueId or a GlobalKey instance. + NOTE - passing a whole DataRow is deprecated. Please use a DataRowIdentifier instead. Priority: * Data will be labeled in priority order. @@ -1174,15 +1195,31 @@ def set_labeling_parameter_overrides(self, data) -> bool: bool, indicates if the operation was a success. """ data = [t[:2] for t in data] - self.validate_labeling_parameter_overrides(data) - data_str = ",\n".join("{dataRow: {id: \"%s\"}, priority: %d }" % - (data_row.uid, priority) - for data_row, priority in data) - id_param = "projectId" - query_str = """mutation SetLabelingParameterOverridesPyApi($%s: ID!){ - project(where: { id: $%s }) {setLabelingParameterOverrides - (data: [%s]) {success}}} """ % (id_param, id_param, data_str) - res = self.client.execute(query_str, {id_param: self.uid}) + validate_labeling_parameter_overrides(data) + + template = Template( + """mutation SetLabelingParameterOverridesPyApi($$projectId: ID!) 
+ {project(where: { id: $$projectId }) + {setLabelingParameterOverrides + (dataWithDataRowIdentifiers: [$dataWithDataRowIdentifiers]) + {success}}} + """) + + data_rows_with_identifiers = "" + for data_row, priority in data: + if isinstance(data_row, DataRow): + data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.uid}\", idType: {IdType.DataRowId}}}, priority: {priority}}}," + elif isinstance(data_row, UniqueId) or isinstance( + data_row, GlobalKey): + data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.key}\", idType: {data_row.id_type}}}, priority: {priority}}}," + else: + raise TypeError( + f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}." + ) + + query_str = template.substitute( + dataWithDataRowIdentifiers=data_rows_with_identifiers) + res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["setLabelingParameterOverrides"]["success"] @overload diff --git a/tests/integration/test_labeling_parameter_overrides.py b/tests/integration/test_labeling_parameter_overrides.py index 20cb6d000..51c56353c 100644 --- a/tests/integration/test_labeling_parameter_overrides.py +++ b/tests/integration/test_labeling_parameter_overrides.py @@ -1,5 +1,6 @@ import pytest from labelbox import DataRow +from labelbox.schema.identifiable import GlobalKey, UniqueId from labelbox.schema.identifiables import GlobalKeys, UniqueIds @@ -27,17 +28,36 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): for override in updated_overrides: assert isinstance(override.data_row(), DataRow) + data = [(UniqueId(data_rows[0].uid), 1, 2), (UniqueId(data_rows[1].uid), 2), + (UniqueId(data_rows[2].uid), 3)] + success = project.set_labeling_parameter_overrides(data) + assert success + updated_overrides = list(project.labeling_parameter_overrides()) + assert len(updated_overrides) == 3 + assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} + assert {o.priority 
for o in updated_overrides} == {1, 2, 3} + + data = [(GlobalKey(data_rows[0].global_key), 2, 2), + (GlobalKey(data_rows[1].global_key), 3, 3), + (GlobalKey(data_rows[2].global_key), 4)] + success = project.set_labeling_parameter_overrides(data) + assert success + updated_overrides = list(project.labeling_parameter_overrides()) + assert len(updated_overrides) == 3 + assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} + assert {o.priority for o in updated_overrides} == {2, 3, 4} + with pytest.raises(TypeError) as exc_info: data = [(data_rows[2], "a_string", 3)] project.set_labeling_parameter_overrides(data) assert str(exc_info.value) == \ - f"Priority must be an int. Found for data_row {data_rows[2]}. Index: 0" + f"Priority must be an int. Found for data_row_identifier {data_rows[2].uid}" with pytest.raises(TypeError) as exc_info: data = [(data_rows[2].uid, 1)] project.set_labeling_parameter_overrides(data) assert str(exc_info.value) == \ - "data_row should be be of type DataRow. Found . Index: 0" + f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. 
Found for data_row_identifier {data_rows[2].uid}" def test_set_labeling_priority(consensus_project_with_batch): diff --git a/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py b/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py new file mode 100644 index 000000000..f9f9a0959 --- /dev/null +++ b/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py @@ -0,0 +1,37 @@ +import pytest +from unittest.mock import MagicMock + +from labelbox.schema.data_row import DataRow +from labelbox.schema.identifiable import GlobalKey, UniqueId +from labelbox.schema.project import validate_labeling_parameter_overrides + + +def test_validate_labeling_parameter_overrides_valid_data(): + mock_data_row = MagicMock(spec=DataRow) + mock_data_row.uid = "abc" + data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)] + validate_labeling_parameter_overrides(data) + + +def test_validate_labeling_parameter_overrides_invalid_data(): + data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)] + with pytest.raises(TypeError): + validate_labeling_parameter_overrides(data) + + +def test_validate_labeling_parameter_overrides_invalid_priority(): + mock_data_row = MagicMock(spec=DataRow) + mock_data_row.uid = "abc" + data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2), + (GlobalKey("hij"), 3)] + with pytest.raises(TypeError): + validate_labeling_parameter_overrides(data) + + +def test_validate_labeling_parameter_overrides_invalid_tuple_length(): + mock_data_row = MagicMock(spec=DataRow) + mock_data_row.uid = "abc" + data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2), + (GlobalKey("hij"))] + with pytest.raises(TypeError): + validate_labeling_parameter_overrides(data)