From 5626b4e4396c62db4a51f93941bec15b6c341091 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Wed, 3 Jan 2024 16:38:37 -0800 Subject: [PATCH] Add method to CatalogSlice to get data row identifiers (both uids and global keys) Also: Make return typed Add deprecation warning to get_data_row_ids --- labelbox/schema/slice.py | 64 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/labelbox/schema/slice.py b/labelbox/schema/slice.py index 58d5d3b41..15fff2277 100644 --- a/labelbox/schema/slice.py +++ b/labelbox/schema/slice.py @@ -1,12 +1,13 @@ -from typing import Optional, List -from labelbox.exceptions import ResourceNotFoundError +from dataclasses import dataclass +from typing import Optional +import warnings from labelbox.orm.db_object import DbObject, experimental -from labelbox.orm.model import Entity, Field +from labelbox.orm.model import Field from labelbox.pagination import PaginatedCollection from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params from labelbox.schema.export_task import ExportTask +from labelbox.schema.identifiable import GlobalKey, UniqueId from labelbox.schema.task import Task -from labelbox.schema.user import User class Slice(DbObject): @@ -34,13 +35,27 @@ class CatalogSlice(Slice): Represents a Slice used for filtering data rows in Catalog. """ + @dataclass + class DataRowIdAndGlobalKey: + id: UniqueId + global_key: Optional[GlobalKey] + + def __init__(self, id: str, global_key: Optional[str]): + self.id = UniqueId(id) + self.global_key = GlobalKey(global_key) if global_key else None + def get_data_row_ids(self) -> PaginatedCollection: """ Fetches all data row ids that match this Slice Returns: - A PaginatedCollection of data row ids + A PaginatedCollection of mapping of data row ids to global keys """ + + warnings.warn( + "get_data_row_ids will be deprecated. Use get_data_row_identifiers instead" + ) + query_str = """ query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) { getDataRowIdsBySavedQuery(input: { @@ -65,6 +80,45 @@ def get_data_row_ids(self) -> PaginatedCollection: obj_class=lambda _, data_row_id: data_row_id, cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor']) + def get_data_row_identifiers(self) -> PaginatedCollection: + """ + Fetches all data row ids and global keys (where defined) that match this Slice + + Returns: + A PaginatedCollection of data row ids + """ + query_str = """ + query getDataRowIdenfifiersBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) { + getDataRowIdentifiersBySavedQuery(input: { + savedQueryId: $id, + after: $from + first: $first + }) { + totalCount + nodes + { + id + globalKey + } + pageInfo { + endCursor + hasNextPage + } + } + } + """ + return PaginatedCollection( + client=self.client, + query=query_str, + params={'id': str(self.uid)}, + dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'], + obj_class=lambda _, data_row_id_and_gk: CatalogSlice. + DataRowIdAndGlobalKey(data_row_id_and_gk.get('id'), + data_row_id_and_gk.get('globalKey', None)), + cursor_path=[ + 'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor' + ]) + @experimental def export(self, task_name: Optional[str] = None,