From 676de2fe63d921501c0e9e65620905ae94c440a5 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Tue, 16 Jan 2024 21:03:01 -0800 Subject: [PATCH] Update ModelSlice: fix get_data_row_ids and add get_data_row_identifiers --- labelbox/client.py | 4 ++ labelbox/schema/slice.py | 103 +++++++++++++++++++++++++++------------ 2 files changed, 76 insertions(+), 31 deletions(-) diff --git a/labelbox/client.py b/labelbox/client.py index 4470190f9..c00a64c79 100644 --- a/labelbox/client.py +++ b/labelbox/client.py @@ -1696,6 +1696,10 @@ def get_model_slice(self, slice_id) -> ModelSlice: } """ res = self.execute(query_str, {"id": slice_id}) + if res is None or res["getSavedQuery"] is None: + raise labelbox.exceptions.ResourceNotFoundError( + ModelSlice, slice_id) + return Entity.ModelSlice(self, res["getSavedQuery"]) def delete_feature_schema_from_ontology( diff --git a/labelbox/schema/slice.py b/labelbox/schema/slice.py index 15fff2277..4ee7ec43b 100644 --- a/labelbox/schema/slice.py +++ b/labelbox/schema/slice.py @@ -29,12 +29,6 @@ class Slice(DbObject): updated_at = Field.DateTime("updated_at") filter = Field.Json("filter") - -class CatalogSlice(Slice): - """ - Represents a Slice used for filtering data rows in Catalog. - """ - @dataclass class DataRowIdAndGlobalKey: id: UniqueId @@ -44,6 +38,18 @@ def __init__(self, id: str, global_key: Optional[str]): self.id = UniqueId(id) self.global_key = GlobalKey(global_key) if global_key else None + def to_hash(self): + return { + "id": self.id.key, + "global_key": self.global_key.key if self.global_key else None + } + + +class CatalogSlice(Slice): + """ + Represents a Slice used for filtering data rows in Catalog. + """ + def get_data_row_ids(self) -> PaginatedCollection: """ Fetches all data row ids that match this Slice @@ -75,7 +81,7 @@ def get_data_row_ids(self) -> PaginatedCollection: return PaginatedCollection( client=self.client, query=query_str, - params={'id': self.uid}, + params={'id': str(self.uid)}, dereferencing=['getDataRowIdsBySavedQuery', 'nodes'], obj_class=lambda _, data_row_id: data_row_id, cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor']) @@ -85,7 +91,7 @@ def get_data_row_identifiers(self) -> PaginatedCollection: Fetches all data row ids and global keys (where defined) that match this Slice Returns: - A PaginatedCollection of data row ids + A PaginatedCollection of Slice.DataRowIdAndGlobalKey """ query_str = """ query getDataRowIdenfifiersBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) { @@ -112,9 +118,9 @@ def get_data_row_identifiers(self) -> PaginatedCollection: query=query_str, params={'id': str(self.uid)}, dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'], - obj_class=lambda _, data_row_id_and_gk: CatalogSlice. - DataRowIdAndGlobalKey(data_row_id_and_gk.get('id'), - data_row_id_and_gk.get('globalKey', None)), + obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( + data_row_id_and_gk.get('id'), + data_row_id_and_gk.get('globalKey', None)), cursor_path=[ 'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor' ]) @@ -224,6 +230,33 @@ class ModelSlice(Slice): Represents a Slice used for filtering data rows in Model. """ + @classmethod + def query_str(cls): + query_str = """ + query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $from: DataRowIdentifierCursorInput, $first: Int!) { + getDataRowIdentifiersBySavedModelQuery(input: { + savedQueryId: $id, + after: $from + first: $first + }) { + totalCount + nodes + { + id + globalKey + } + pageInfo { + endCursor { + dataRowId + globalKey + } + hasNextPage + } + } + } + """ + return query_str + def get_data_row_ids(self) -> PaginatedCollection: """ Fetches all data row ids that match this Slice @@ -231,26 +264,34 @@ def get_data_row_ids(self) -> PaginatedCollection: Returns: A PaginatedCollection of data row ids """ - query_str = """ - query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) { - getDataRowIdsBySavedQuery(input: { - savedQueryId: $id, - after: $from - first: $first - }) { - totalCount - nodes - pageInfo { - endCursor - hasNextPage - } - } - } + return PaginatedCollection( + client=self.client, + query=ModelSlice.query_str(), + params={'id': str(self.uid)}, + dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], + obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id' + ), + cursor_path=[ + 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', + 'endCursor' + ]) + + def get_data_row_identifiers(self) -> PaginatedCollection: + """ + Fetches all data row ids and global keys (where defined) that match this Slice + + Returns: + A PaginatedCollection of Slice.DataRowIdAndGlobalKey """ return PaginatedCollection( client=self.client, - query=query_str, - params={'id': self.uid}, - dereferencing=['getDataRowIdsBySavedQuery', 'nodes'], - obj_class=lambda _, data_row_id: data_row_id, - cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor']) + query=ModelSlice.query_str(), + params={'id': str(self.uid)}, + dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], + obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( + data_row_id_and_gk.get('id'), + data_row_id_and_gk.get('globalKey', None)), + cursor_path=[ + 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', + 'endCursor' + ])