Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion labelbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@
from labelbox.schema.resource_tag import ResourceTag
from labelbox.schema.project_resource_tag import ProjectResourceTag
from labelbox.schema.media_type import MediaType
from labelbox.schema.slice import Slice, CatalogSlice
from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice
from labelbox.schema.queue_mode import QueueMode
from labelbox.schema.task_queue import TaskQueue
26 changes: 25 additions & 1 deletion labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from labelbox.schema.user import User
from labelbox.schema.project import Project
from labelbox.schema.role import Role
from labelbox.schema.slice import CatalogSlice
from labelbox.schema.slice import CatalogSlice, ModelSlice
from labelbox.schema.queue_mode import QueueMode

from labelbox.schema.media_type import MediaType, get_media_type_validation_error
Expand Down Expand Up @@ -1384,3 +1384,27 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice:
"""
res = self.execute(query_str, {'id': slice_id})
return Entity.CatalogSlice(self, res['getSavedQuery'])

def get_model_slice(self, slice_id) -> ModelSlice:
"""
Fetches a Model Slice by ID.

Args:
slice_id (str): The ID of the Slice
Returns:
ModelSlice
"""
query_str = """
query getSavedQueryPyApi($id: ID!) {
getSavedQuery(id: $id) {
id
name
description
filter
createdAt
updatedAt
}
}
"""
res = self.execute(query_str, {"id": slice_id})
return Entity.ModelSlice(self, res["getSavedQuery"])
1 change: 1 addition & 0 deletions labelbox/orm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ class Entity(metaclass=EntityMeta):
Project: Type[labelbox.Project]
Batch: Type[labelbox.Batch]
CatalogSlice: Type[labelbox.CatalogSlice]
ModelSlice: Type[labelbox.ModelSlice]
TaskQueue: Type[labelbox.TaskQueue]

@classmethod
Expand Down
38 changes: 38 additions & 0 deletions labelbox/schema/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Slice(DbObject):
updated_at (datetime)
filter (json)
"""

name = Field.String("name")
description = Field.String("description")
created_at = Field.DateTime("created_at")
Expand Down Expand Up @@ -57,3 +58,40 @@ def get_data_row_ids(self) -> PaginatedCollection:
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
obj_class=lambda _, data_row_id: data_row_id,
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])


class ModelSlice(Slice):
"""
Represents a Slice used for filtering data rows in Model.
"""

def get_data_row_ids(self) -> PaginatedCollection:
"""
Fetches all data row ids that match this Slice

Returns:
A PaginatedCollection of data row ids
"""
query_str = """
query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
getDataRowIdsBySavedQuery(input: {
savedQueryId: $id,
after: $from
first: $first
}) {
totalCount
nodes
pageInfo {
endCursor
hasNextPage
}
}
}
"""
return PaginatedCollection(
client=self.client,
query=query_str,
params={'id': self.uid},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

according to gql above, $first is required, but I do not see it in the params. Is it ok?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prob handled in PaginatedCollection?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
obj_class=lambda _, data_row_id: data_row_id,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not following here... according to PaginatedCollection documentation, obj_class is a class of an object...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That documentation is 4 years old, I think it's outdated. I was actually just copying the code from CatalogSlice here 😅

cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])