diff --git a/labelbox/client.py b/labelbox/client.py
index 87c85d732..4470190f9 100644
--- a/labelbox/client.py
+++ b/labelbox/client.py
@@ -183,6 +183,7 @@ def convert_value(value):
 
         endpoint = self.endpoint if not experimental else self.endpoint.replace(
             "/graphql", "/_gql")
+
         try:
             request = {
                 'url': endpoint,
diff --git a/labelbox/schema/data_row_metadata.py b/labelbox/schema/data_row_metadata.py
index b7b4105e1..93ec2ad79 100644
--- a/labelbox/schema/data_row_metadata.py
+++ b/labelbox/schema/data_row_metadata.py
@@ -3,9 +3,12 @@
 from copy import deepcopy
 from enum import Enum
 from itertools import chain
-from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator
+import warnings
+
+from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator, overload
 
 from pydantic import BaseModel, conlist, constr
 
+from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
 from labelbox.schema.ontology import SchemaId
 from labelbox.utils import _CamelCaseMixin, format_iso_datetime, format_iso_from_string
@@ -601,13 +604,23 @@ def _batch_delete(
                                  items,
                                  batch_size=self._batch_size)
 
+    @overload
     def bulk_export(self, data_row_ids: List[str]) -> List[DataRowMetadata]:
+        pass
+
+    @overload
+    def bulk_export(self,
+                    data_row_ids: DataRowIdentifiers) -> List[DataRowMetadata]:
+        pass
+
+    def bulk_export(self, data_row_ids) -> List[DataRowMetadata]:
         """ Exports metadata for a list of data rows
 
         >>> mdo.bulk_export([data_row.uid for data_row in data_rows])
 
         Args:
-            data_row_ids: List of data data rows to fetch metadata for
+            data_row_ids: List of data data rows to fetch metadata for. This can be a list of strings or a DataRowIdentifiers object
+            DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
 
         Returns:
             A list of DataRowMetadata. There will be one DataRowMetadata for each data_row_id passed in.
@@ -615,13 +628,20 @@ def bulk_export(self, data_row_ids: List[str]) -> List[DataRowMetadata]:
             Data rows without metadata will have empty `fields`.
 
         """
-
         if not len(data_row_ids):
             raise ValueError("Empty list passed")
 
-        def _bulk_export(_data_row_ids: List[str]) -> List[DataRowMetadata]:
-            query = """query dataRowCustomMetadataPyApi($dataRowIds: [ID!]!) {
-                dataRowCustomMetadata(where: {dataRowIds : $dataRowIds}) {
+        if isinstance(data_row_ids,
+                      list) and len(data_row_ids) > 0 and isinstance(
+                          data_row_ids[0], str):
+            data_row_ids = UniqueIds(data_row_ids)
+            warnings.warn("Using data row ids will be deprecated. Please use "
+                          "UniqueIds or GlobalKeys instead.")
+
+        def _bulk_export(
+                _data_row_ids: DataRowIdentifiers) -> List[DataRowMetadata]:
+            query = """query dataRowCustomMetadataPyApi($dataRowIdentifiers: DataRowCustomMetadataDataRowIdentifiersInput) {
+                dataRowCustomMetadata(where: {dataRowIdentifiers : $dataRowIdentifiers}) {
                     dataRowId
                     globalKey
                     fields {
@@ -633,8 +653,12 @@ def _bulk_export(_data_row_ids: List[str]) -> List[DataRowMetadata]:
             """
             return self.parse_metadata(
                 self._client.execute(
-                    query,
-                    {"dataRowIds": _data_row_ids})['dataRowCustomMetadata'])
+                    query, {
+                        "dataRowIdentifiers": {
+                            "ids": [id for id in _data_row_ids],
+                            "idType": _data_row_ids.id_type
+                        }
+                    })['dataRowCustomMetadata'])
 
         return _batch_operations(_bulk_export,
                                  data_row_ids,
diff --git a/labelbox/schema/identifiables.py b/labelbox/schema/identifiables.py
index a77704236..e6780128c 100644
--- a/labelbox/schema/identifiables.py
+++ b/labelbox/schema/identifiables.py
@@ -12,15 +12,32 @@ def __init__(self, iterable, id_type: IdType):
             id_type: The type of id used to identify a data row.
         """
         self._iterable = iterable
-        self._index = 0
         self._id_type = id_type
 
+    @property
+    def id_type(self):
+        return self._id_type
+
     def __iter__(self):
         return iter(self._iterable)
 
+    def __getitem__(self, index):
+        if isinstance(index, slice):
+            ids = self._iterable[index]
+            return self.__class__(ids)  # type: ignore
+        return self._iterable[index]
+
+    def __len__(self):
+        return len(self._iterable)
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self._iterable})"
 
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Identifiables):
+            return False
+        return self._iterable == other._iterable and self._id_type == other._id_type
+
 
 class UniqueIds(Identifiables):
     """
diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py
index 73680076b..6580576e2 100644
--- a/labelbox/schema/project.py
+++ b/labelbox/schema/project.py
@@ -1289,7 +1289,7 @@ def update_data_row_labeling_priority(
                 project_param: self.uid,
                 data_rows_param: {
                     "ids": [id for id in data_rows],
-                    "idType": data_rows._id_type,
+                    "idType": data_rows.id_type,
                 },
             })["project"][method]
 
@@ -1484,7 +1484,7 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str):
                 "queueId": task_queue_id,
                 "dataRowIdentifiers": {
                     "ids": [id for id in data_row_ids],
-                    "idType": data_row_ids._id_type,
+                    "idType": data_row_ids.id_type,
                 },
             },
             timeout=180.0,
diff --git a/tests/integration/test_data_row_metadata.py b/tests/integration/test_data_row_metadata.py
index 3d2fc9224..bc9959a2b 100644
--- a/tests/integration/test_data_row_metadata.py
+++ b/tests/integration/test_data_row_metadata.py
@@ -7,6 +7,7 @@
 from labelbox.exceptions import MalformedQueryException
 from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadata, DataRowMetadataKind, DeleteDataRowMetadata, \
     DataRowMetadataOntology, _parse_metadata_schema
+from labelbox.schema.identifiables import GlobalKeys, UniqueIds
 
 INVALID_SCHEMA_ID = "1" * 25
 FAKE_SCHEMA_ID = "0" * 25
@@ -102,6 +103,16 @@ def test_bulk_export_datarow_metadata(data_row, mdo: DataRowMetadataOntology):
     assert exported[0].data_row_id == data_row.uid
     assert len([field for field in exported[0].fields]) == 3
 
+    exported = mdo.bulk_export(UniqueIds([data_row.uid]))
+    assert exported[0].global_key == data_row.global_key
+    assert exported[0].data_row_id == data_row.uid
+    assert len([field for field in exported[0].fields]) == 3
+
+    exported = mdo.bulk_export(GlobalKeys([data_row.global_key]))
+    assert exported[0].global_key == data_row.global_key
+    assert exported[0].data_row_id == data_row.uid
+    assert len([field for field in exported[0].fields]) == 3
+
 
 def test_get_datarow_metadata_ontology(mdo):
     assert len(mdo.fields)
diff --git a/tests/unit/test_unit_identifiables.py b/tests/unit/test_unit_identifiables.py
index b52b3d4eb..c26fa13ba 100644
--- a/tests/unit/test_unit_identifiables.py
+++ b/tests/unit/test_unit_identifiables.py
@@ -5,14 +5,23 @@ def test_unique_ids():
     ids = ["a", "b", "c"]
     identifiables = UniqueIds(ids)
     assert [i for i in identifiables] == ids
-    assert identifiables._id_type == "ID"
+    assert identifiables.id_type == "ID"
+    assert len(identifiables) == 3
 
 
 def test_global_keys():
     ids = ["a", "b", "c"]
    identifiables = GlobalKeys(ids)
     assert [i for i in identifiables] == ids
-    assert identifiables._id_type == "GKEY"
+    assert identifiables.id_type == "GKEY"
+    assert len(identifiables) == 3
+
+
+def test_index_access():
+    ids = ["a", "b", "c"]
+    identifiables = GlobalKeys(ids)
+    assert identifiables[0] == "a"
+    assert identifiables[1:3] == GlobalKeys(["b", "c"])
 
 
 def test_repr():
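For context, a minimal usage sketch of the API surface this patch introduces (not part of the diff itself): Identifiables gains a public id_type property, len(), index/slice access, and equality, and bulk_export accepts either a plain list of data row ids (with a deprecation warning) or a UniqueIds/GlobalKeys wrapper. The client setup and the "<...>" ids below are placeholder assumptions, not values taken from the patch.

    # Usage sketch (not part of the patch). Assumes a valid API key and
    # existing data rows; "<data-row-id>" and "<global-key>" are placeholders.
    from labelbox import Client
    from labelbox.schema.identifiables import GlobalKeys, UniqueIds

    # New Identifiables behaviour from this diff: public id_type, len(),
    # index/slice access, and equality.
    row_ids = UniqueIds(["a", "b", "c"])
    assert row_ids.id_type == "ID"
    assert len(row_ids) == 3
    assert row_ids[0] == "a"
    assert row_ids[1:3] == UniqueIds(["b", "c"])

    keys = GlobalKeys(["a", "b", "c"])
    assert keys.id_type == "GKEY"

    # bulk_export now accepts either identifier wrapper; a bare list of
    # strings still works but emits the warning added in this diff.
    client = Client(api_key="<api-key>")
    mdo = client.get_data_row_metadata_ontology()
    by_id = mdo.bulk_export(UniqueIds(["<data-row-id>"]))
    by_key = mdo.bulk_export(GlobalKeys(["<global-key>"]))
    legacy = mdo.bulk_export(["<data-row-id>"])  # warns: use UniqueIds/GlobalKeys

Note that slicing returns the same wrapper class, which is what lets _batch_operations split a DataRowIdentifiers into batches while preserving id_type for each request.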