From 903fdfc4fce9ac3fc0dcf46b1144107947828251 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Wed, 13 Sep 2023 15:06:37 -0700 Subject: [PATCH] Add sdk method to get data row by global key --- labelbox/client.py | 17 +++++++++++++++++ tests/integration/conftest.py | 19 ++++++++++++++++++- tests/integration/test_data_rows.py | 7 +++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/labelbox/client.py b/labelbox/client.py index f6578cb9c..de7409b5e 100644 --- a/labelbox/client.py +++ b/labelbox/client.py @@ -24,6 +24,7 @@ from labelbox.pagination import PaginatedCollection from labelbox.schema.data_row_metadata import DataRowMetadataOntology from labelbox.schema.dataset import Dataset +from labelbox.schema.data_row import DataRow from labelbox.schema.enums import CollectionJobStatus from labelbox.schema.iam_integration import IAMIntegration from labelbox.schema import role @@ -430,6 +431,7 @@ def _get_single(self, db_object_type, uid): of the given type for the given ID. """ query_str, params = query.get_single(db_object_type, uid) + res = self.execute(query_str, params) res = res and res.get(utils.camel_case(db_object_type.type_name())) if res is None: @@ -727,6 +729,21 @@ def get_data_row(self, data_row_id): return self._get_single(Entity.DataRow, data_row_id) + def get_data_row_by_global_key(self, global_key: str) -> DataRow: + """ + Returns: DataRow: returns a single data row given the global key + """ + + res = self.get_data_row_ids_for_global_keys([global_key]) + if res['status'] != "SUCCESS": + raise labelbox.exceptions.MalformedQueryException(res['errors'][0]) + if len(res['results']) == 0: + raise labelbox.exceptions.ResourceNotFoundError( + Entity.DataRow, {global_key: global_key}) + data_row_id = res['results'][0] + + return self.get_data_row(data_row_id) + def get_data_row_metadata_ontology(self) -> DataRowMetadataOntology: """ diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f14c45c56..0a8396a91 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -195,11 +195,12 @@ def small_dataset(dataset: Dataset): @pytest.fixture def data_row(dataset, image_url, rand_gen): + global_key = f"global-key-{rand_gen(str)}" task = dataset.create_data_rows([ { "row_data": image_url, "external_id": "my-image", - "global_key": f"global-key-{rand_gen(str)}" + "global_key": global_key }, ]) task.wait_till_done() @@ -208,6 +209,22 @@ def data_row(dataset, image_url, rand_gen): dr.delete() +@pytest.fixture +def data_row_and_global_key(dataset, image_url, rand_gen): + global_key = f"global-key-{rand_gen(str)}" + task = dataset.create_data_rows([ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": global_key + }, + ]) + task.wait_till_done() + dr = dataset.data_rows().get_one() + yield dr, global_key + dr.delete() + + # can be used with # @pytest.mark.parametrize('data_rows', [], indirect=True) # if omitted, count defaults to 1 diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py index 75e04863b..6671beb60 100644 --- a/tests/integration/test_data_rows.py +++ b/tests/integration/test_data_rows.py @@ -118,6 +118,13 @@ def make_metadata_fields_dict(): return fields +def test_get_data_row_by_global_key(data_row_and_global_key, client, rand_gen): + _, global_key = data_row_and_global_key + data_row = client.get_data_row_by_global_key(global_key) + assert type(data_row) == DataRow + assert data_row.global_key == global_key + + def test_get_data_row(data_row, client): assert client.get_data_row(data_row.uid)