Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,10 @@ def get_model_slice(self, slice_id) -> ModelSlice:
}
"""
res = self.execute(query_str, {"id": slice_id})
if res is None or res["getSavedQuery"] is None:
raise labelbox.exceptions.ResourceNotFoundError(
ModelSlice, slice_id)

return Entity.ModelSlice(self, res["getSavedQuery"])

def delete_feature_schema_from_ontology(
Expand Down
103 changes: 72 additions & 31 deletions labelbox/schema/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ class Slice(DbObject):
updated_at = Field.DateTime("updated_at")
filter = Field.Json("filter")


class CatalogSlice(Slice):
"""
Represents a Slice used for filtering data rows in Catalog.
"""

@dataclass
class DataRowIdAndGlobalKey:
id: UniqueId
Expand All @@ -44,6 +38,18 @@ def __init__(self, id: str, global_key: Optional[str]):
self.id = UniqueId(id)
self.global_key = GlobalKey(global_key) if global_key else None

def to_hash(self):
return {
"id": self.id.key,
"global_key": self.global_key.key if self.global_key else None
}


class CatalogSlice(Slice):
"""
Represents a Slice used for filtering data rows in Catalog.
"""

def get_data_row_ids(self) -> PaginatedCollection:
"""
Fetches all data row ids that match this Slice
Expand Down Expand Up @@ -75,7 +81,7 @@ def get_data_row_ids(self) -> PaginatedCollection:
return PaginatedCollection(
client=self.client,
query=query_str,
params={'id': self.uid},
params={'id': str(self.uid)},
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
obj_class=lambda _, data_row_id: data_row_id,
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])
Expand All @@ -85,7 +91,7 @@ def get_data_row_identifiers(self) -> PaginatedCollection:
Fetches all data row ids and global keys (where defined) that match this Slice

Returns:
A PaginatedCollection of data row ids
A PaginatedCollection of Slice.DataRowIdAndGlobalKey
"""
query_str = """
query getDataRowIdenfifiersBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
Expand All @@ -112,9 +118,9 @@ def get_data_row_identifiers(self) -> PaginatedCollection:
query=query_str,
params={'id': str(self.uid)},
dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'],
obj_class=lambda _, data_row_id_and_gk: CatalogSlice.
DataRowIdAndGlobalKey(data_row_id_and_gk.get('id'),
data_row_id_and_gk.get('globalKey', None)),
obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey(
data_row_id_and_gk.get('id'),
data_row_id_and_gk.get('globalKey', None)),
cursor_path=[
'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor'
])
Expand Down Expand Up @@ -224,33 +230,68 @@ class ModelSlice(Slice):
Represents a Slice used for filtering data rows in Model.
"""

@classmethod
def query_str(cls):
query_str = """
query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $from: DataRowIdentifierCursorInput, $first: Int!) {
getDataRowIdentifiersBySavedModelQuery(input: {
savedQueryId: $id,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comma might/should not be necessary

after: $from
first: $first
}) {
totalCount
nodes
{
id
globalKey
}
pageInfo {
endCursor {
dataRowId
globalKey
}
hasNextPage
}
}
}
"""
return query_str

def get_data_row_ids(self) -> PaginatedCollection:
"""
Fetches all data row ids that match this Slice

Returns:
A PaginatedCollection of data row ids
"""
query_str = """
query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
getDataRowIdsBySavedQuery(input: {
savedQueryId: $id,
after: $from
first: $first
}) {
totalCount
nodes
pageInfo {
endCursor
hasNextPage
}
}
}
return PaginatedCollection(
client=self.client,
query=ModelSlice.query_str(),
params={'id': str(self.uid)},
dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'],
obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id'
),
cursor_path=[
'getDataRowIdentifiersBySavedModelQuery', 'pageInfo',
'endCursor'
])

def get_data_row_identifiers(self) -> PaginatedCollection:
"""
Fetches all data row ids and global keys (where defined) that match this Slice

Returns:
A PaginatedCollection of Slice.DataRowIdAndGlobalKey
"""
return PaginatedCollection(
client=self.client,
query=query_str,
params={'id': self.uid},
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
obj_class=lambda _, data_row_id: data_row_id,
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])
query=ModelSlice.query_str(),
params={'id': str(self.uid)},
dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'],
obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey(
data_row_id_and_gk.get('id'),
data_row_id_and_gk.get('globalKey', None)),
cursor_path=[
'getDataRowIdentifiersBySavedModelQuery', 'pageInfo',
'endCursor'
])