diff --git a/labelbox/__init__.py b/labelbox/__init__.py index d9acd8374..8d58e46cb 100644 --- a/labelbox/__init__.py +++ b/labelbox/__init__.py @@ -27,6 +27,6 @@ from labelbox.schema.resource_tag import ResourceTag from labelbox.schema.project_resource_tag import ProjectResourceTag from labelbox.schema.media_type import MediaType -from labelbox.schema.slice import Slice, CatalogSlice +from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice from labelbox.schema.queue_mode import QueueMode from labelbox.schema.task_queue import TaskQueue diff --git a/labelbox/client.py b/labelbox/client.py index 172184608..5af41ffbb 100644 --- a/labelbox/client.py +++ b/labelbox/client.py @@ -33,7 +33,7 @@ from labelbox.schema.user import User from labelbox.schema.project import Project from labelbox.schema.role import Role -from labelbox.schema.slice import CatalogSlice +from labelbox.schema.slice import CatalogSlice, ModelSlice from labelbox.schema.queue_mode import QueueMode from labelbox.schema.media_type import MediaType, get_media_type_validation_error @@ -1384,3 +1384,27 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice: """ res = self.execute(query_str, {'id': slice_id}) return Entity.CatalogSlice(self, res['getSavedQuery']) + + def get_model_slice(self, slice_id) -> ModelSlice: + """ + Fetches a Model Slice by ID. + + Args: + slice_id (str): The ID of the Slice + Returns: + ModelSlice + """ + query_str = """ + query getSavedQueryPyApi($id: ID!) { + getSavedQuery(id: $id) { + id + name + description + filter + createdAt + updatedAt + } + } + """ + res = self.execute(query_str, {"id": slice_id}) + return Entity.ModelSlice(self, res["getSavedQuery"]) diff --git a/labelbox/orm/model.py b/labelbox/orm/model.py index f4f09d8c8..f2c5b7a93 100644 --- a/labelbox/orm/model.py +++ b/labelbox/orm/model.py @@ -378,6 +378,7 @@ class Entity(metaclass=EntityMeta): Project: Type[labelbox.Project] Batch: Type[labelbox.Batch] CatalogSlice: Type[labelbox.CatalogSlice] + ModelSlice: Type[labelbox.ModelSlice] TaskQueue: Type[labelbox.TaskQueue] @classmethod diff --git a/labelbox/schema/slice.py b/labelbox/schema/slice.py index 5d6a0fb62..ab70f4b1b 100644 --- a/labelbox/schema/slice.py +++ b/labelbox/schema/slice.py @@ -15,6 +15,7 @@ class Slice(DbObject): updated_at (datetime) filter (json) """ + name = Field.String("name") description = Field.String("description") created_at = Field.DateTime("created_at") @@ -57,3 +58,40 @@ def get_data_row_ids(self) -> PaginatedCollection: dereferencing=['getDataRowIdsBySavedQuery', 'nodes'], obj_class=lambda _, data_row_id: data_row_id, cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor']) + + +class ModelSlice(Slice): + """ + Represents a Slice used for filtering data rows in Model. + """ + + def get_data_row_ids(self) -> PaginatedCollection: + """ + Fetches all data row ids that match this Slice + + Returns: + A PaginatedCollection of data row ids + """ + query_str = """ + query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) { + getDataRowIdsBySavedQuery(input: { + savedQueryId: $id, + after: $from + first: $first + }) { + totalCount + nodes + pageInfo { + endCursor + hasNextPage + } + } + } + """ + return PaginatedCollection( + client=self.client, + query=query_str, + params={'id': self.uid}, + dereferencing=['getDataRowIdsBySavedQuery', 'nodes'], + obj_class=lambda _, data_row_id: data_row_id, + cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])