Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions labelbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@
from labelbox.schema.queue_mode import QueueMode
from labelbox.schema.task_queue import TaskQueue
from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds
from labelbox.schema.identifiable import UniqueId, GlobalKey
1 change: 1 addition & 0 deletions labelbox/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
import labelbox.schema.iam_integration
import labelbox.schema.media_type
import labelbox.schema.identifiables
import labelbox.schema.identifiable
13 changes: 13 additions & 0 deletions labelbox/schema/id_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from enum import Enum


class IdType(str, Enum):
"""
The type of id used to identify a data row.

Currently supported types are:
- DataRowId: The id assigned to a data row by Labelbox.
- GlobalKey: The id assigned to a data row by the user.
"""
DataRowId = "ID"
GlobalKey = "GKEY"
32 changes: 23 additions & 9 deletions labelbox/schema/identifiable.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,52 @@
from abc import ABC, abstractmethod
from typing import List, Union
from abc import ABC
from typing import Union

from labelbox.schema.id_type import IdType


class Identifiable(ABC):
"""
Base class for any object representing a unique identifier.
"""

def __init__(self, key: str):
def __init__(self, key: str, id_type: IdType):
self._key = key
self._id_type = id_type

@property
def key(self):
return self.key
return self._key

@property
def id_type(self):
return self._id_type

def __eq__(self, other):
return other.key == self.key
return other.key == self.key and other.id_type == self.id_type

def __hash__(self):
hash(self.key)
return hash((self.key, self.id_type))

def __str__(self):
return self.key.__str__()
return f"{self.id_type}:{self.key}"


class UniqueId(Identifiable):
"""
Represents a unique, internally generated id.
"""
pass

def __init__(self, key: str):
super().__init__(key, IdType.DataRowId)


class GlobalKey(Identifiable):
"""
Represents a user generated id.
"""
pass

def __init__(self, key: str):
super().__init__(key, IdType.GlobalKey)


DataRowIdentifier = Union[UniqueId, GlobalKey]
13 changes: 1 addition & 12 deletions labelbox/schema/identifiables.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
from enum import Enum
from typing import List, Union


class IdType(str, Enum):
"""
The type of id used to identify a data row.

Currently supported types are:
- DataRowId: The id assigned to a data row by Labelbox.
- GlobalKey: The id assigned to a data row by the user.
"""
DataRowId = "ID"
GlobalKey = "GKEY"
from labelbox.schema.id_type import IdType


class Identifiables:
Expand Down
99 changes: 68 additions & 31 deletions labelbox/schema/project.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import logging
from string import Template
import time
import warnings
from collections import namedtuple
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, overload
from urllib.parse import urlparse

import requests
Expand All @@ -28,6 +29,8 @@
from labelbox.schema.export_filters import ProjectExportFilters, validate_datetime, build_filters
from labelbox.schema.export_params import ProjectExportParams
from labelbox.schema.export_task import ExportTask
from labelbox.schema.id_type import IdType
from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
from labelbox.schema.media_type import MediaType
from labelbox.schema.queue_mode import QueueMode
Expand All @@ -43,9 +46,38 @@
except ImportError:
pass

DataRowPriority = int
LabelingParameterOverrideInput = Tuple[Union[DataRow, DataRowIdentifier],
DataRowPriority]

logger = logging.getLogger(__name__)


def validate_labeling_parameter_overrides(
data: List[LabelingParameterOverrideInput]) -> None:
for idx, row in enumerate(data):
if len(row) < 2:
raise TypeError(
f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
)
data_row_identifier = row[0]
priority = row[1]
valid_types = (Entity.DataRow, UniqueId, GlobalKey)
if not isinstance(data_row_identifier, valid_types):
raise TypeError(
f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}"
)

if not isinstance(priority, int):
if isinstance(data_row_identifier, Entity.DataRow):
id = data_row_identifier.uid
else:
id = data_row_identifier
raise TypeError(
f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}"
)


class Project(DbObject, Updateable, Deletable):
""" A Project is a container that includes a labeling frontend, an ontology,
datasets and labels.
Expand Down Expand Up @@ -1129,36 +1161,25 @@ def get_queue_mode(self) -> "QueueMode":
else:
raise ValueError("Status not known")

def validate_labeling_parameter_overrides(self, data) -> None:
for idx, row in enumerate(data):
if len(row) < 2:
raise TypeError(
f"Data must be a list of tuples containing a DataRow and priority (int). Found {len(row)} items. Index: {idx}"
)
data_row = row[0]
priority = row[1]
if not isinstance(data_row, Entity.DataRow):
raise TypeError(
f"data_row should be be of type DataRow. Found {type(data_row)}. Index: {idx}"
)

if not isinstance(priority, int):
raise TypeError(
f"Priority must be an int. Found {type(priority)} for data_row {data_row}. Index: {idx}"
)

def set_labeling_parameter_overrides(self, data) -> bool:
def set_labeling_parameter_overrides(
self, data: List[LabelingParameterOverrideInput]) -> bool:
""" Adds labeling parameter overrides to this project.

See information on priority here:
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system

>>> project.set_labeling_parameter_overrides([
>>> (data_row_1, 2), (data_row_2, 1)])
>>> (data_row_id1, 2), (data_row_id2, 1)])
or
>>> project.set_labeling_parameter_overrides([
>>> (data_row_gk1, 2), (data_row_gk2, 1)])

Args:
data (iterable): An iterable of tuples. Each tuple must contain
(DataRow, priority<int>) for the new override.
either (DataRow, DataRowPriority<int>)
or (DataRowIdentifier, priority<int>) for the new override.
DataRowIdentifier is an object representing a data row id or a global key. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
NOTE - passing whole DatRow is deprecated. Please use a DataRowIdentifier instead.

Priority:
* Data will be labeled in priority order.
Expand All @@ -1174,15 +1195,31 @@ def set_labeling_parameter_overrides(self, data) -> bool:
bool, indicates if the operation was a success.
"""
data = [t[:2] for t in data]
self.validate_labeling_parameter_overrides(data)
data_str = ",\n".join("{dataRow: {id: \"%s\"}, priority: %d }" %
(data_row.uid, priority)
for data_row, priority in data)
id_param = "projectId"
query_str = """mutation SetLabelingParameterOverridesPyApi($%s: ID!){
project(where: { id: $%s }) {setLabelingParameterOverrides
(data: [%s]) {success}}} """ % (id_param, id_param, data_str)
res = self.client.execute(query_str, {id_param: self.uid})
validate_labeling_parameter_overrides(data)

template = Template(
"""mutation SetLabelingParameterOverridesPyApi($$projectId: ID!)
{project(where: { id: $$projectId })
{setLabelingParameterOverrides
(dataWithDataRowIdentifiers: [$dataWithDataRowIdentifiers])
{success}}}
""")

data_rows_with_identifiers = ""
for data_row, priority in data:
if isinstance(data_row, DataRow):
data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.uid}\", idType: {IdType.DataRowId}}}, priority: {priority}}},"
elif isinstance(data_row, UniqueId) or isinstance(
data_row, GlobalKey):
data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.key}\", idType: {data_row.id_type}}}, priority: {priority}}},"
else:
raise TypeError(
f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}."
)

query_str = template.substitute(
dataWithDataRowIdentifiers=data_rows_with_identifiers)
res = self.client.execute(query_str, {"projectId": self.uid})
return res["project"]["setLabelingParameterOverrides"]["success"]

@overload
Expand Down
24 changes: 22 additions & 2 deletions tests/integration/test_labeling_parameter_overrides.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from labelbox import DataRow
from labelbox.schema.identifiable import GlobalKey, UniqueId
from labelbox.schema.identifiables import GlobalKeys, UniqueIds


Expand Down Expand Up @@ -27,17 +28,36 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
for override in updated_overrides:
assert isinstance(override.data_row(), DataRow)

data = [(UniqueId(data_rows[0].uid), 1, 2), (UniqueId(data_rows[1].uid), 2),
(UniqueId(data_rows[2].uid), 3)]
success = project.set_labeling_parameter_overrides(data)
assert success
updated_overrides = list(project.labeling_parameter_overrides())
assert len(updated_overrides) == 3
assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1}
assert {o.priority for o in updated_overrides} == {1, 2, 3}

data = [(GlobalKey(data_rows[0].global_key), 2, 2),
(GlobalKey(data_rows[1].global_key), 3, 3),
(GlobalKey(data_rows[2].global_key), 4)]
success = project.set_labeling_parameter_overrides(data)
assert success
updated_overrides = list(project.labeling_parameter_overrides())
assert len(updated_overrides) == 3
assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1}
assert {o.priority for o in updated_overrides} == {2, 3, 4}

with pytest.raises(TypeError) as exc_info:
data = [(data_rows[2], "a_string", 3)]
project.set_labeling_parameter_overrides(data)
assert str(exc_info.value) == \
f"Priority must be an int. Found <class 'str'> for data_row {data_rows[2]}. Index: 0"
f"Priority must be an int. Found <class 'str'> for data_row_identifier {data_rows[2].uid}"

with pytest.raises(TypeError) as exc_info:
data = [(data_rows[2].uid, 1)]
project.set_labeling_parameter_overrides(data)
assert str(exc_info.value) == \
"data_row should be be of type DataRow. Found <class 'str'>. Index: 0"
f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found <class 'str'> for data_row_identifier {data_rows[2].uid}"


def test_set_labeling_priority(consensus_project_with_batch):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
from unittest.mock import MagicMock

from labelbox.schema.data_row import DataRow
from labelbox.schema.identifiable import GlobalKey, UniqueId
from labelbox.schema.project import validate_labeling_parameter_overrides


def test_validate_labeling_parameter_overrides_valid_data():
mock_data_row = MagicMock(spec=DataRow)
mock_data_row.uid = "abc"
data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
validate_labeling_parameter_overrides(data)


def test_validate_labeling_parameter_overrides_invalid_data():
data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
with pytest.raises(TypeError):
validate_labeling_parameter_overrides(data)


def test_validate_labeling_parameter_overrides_invalid_priority():
mock_data_row = MagicMock(spec=DataRow)
mock_data_row.uid = "abc"
data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2),
(GlobalKey("hij"), 3)]
with pytest.raises(TypeError):
validate_labeling_parameter_overrides(data)


def test_validate_labeling_parameter_overrides_invalid_tuple_length():
mock_data_row = MagicMock(spec=DataRow)
mock_data_row.uid = "abc"
data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2),
(GlobalKey("hij"))]
with pytest.raises(TypeError):
validate_labeling_parameter_overrides(data)