Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 59 additions & 5 deletions labelbox/schema/slice.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional, List
from labelbox.exceptions import ResourceNotFoundError
from dataclasses import dataclass
from typing import Optional
import warnings
from labelbox.orm.db_object import DbObject, experimental
from labelbox.orm.model import Entity, Field
from labelbox.orm.model import Field
from labelbox.pagination import PaginatedCollection
from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
from labelbox.schema.export_task import ExportTask
from labelbox.schema.identifiable import GlobalKey, UniqueId
from labelbox.schema.task import Task
from labelbox.schema.user import User


class Slice(DbObject):
Expand Down Expand Up @@ -34,13 +35,27 @@ class CatalogSlice(Slice):
Represents a Slice used for filtering data rows in Catalog.
"""

@dataclass
class DataRowIdAndGlobalKey:
id: UniqueId
global_key: Optional[GlobalKey]

def __init__(self, id: str, global_key: Optional[str]):
self.id = UniqueId(id)
self.global_key = GlobalKey(global_key) if global_key else None

def get_data_row_ids(self) -> PaginatedCollection:
"""
Fetches all data row ids that match this Slice

Returns:
A PaginatedCollection of data row ids
A PaginatedCollection of mapping of data row ids to global keys
"""

warnings.warn(
"get_data_row_ids will be deprecated. Use get_data_row_identifiers instead"
)

query_str = """
query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
getDataRowIdsBySavedQuery(input: {
Expand All @@ -65,6 +80,45 @@ def get_data_row_ids(self) -> PaginatedCollection:
obj_class=lambda _, data_row_id: data_row_id,
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])

def get_data_row_identifiers(self) -> PaginatedCollection:
"""
Fetches all data row ids and global keys (where defined) that match this Slice

Returns:
A PaginatedCollection of data row ids
"""
query_str = """
query getDataRowIdenfifiersBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
getDataRowIdentifiersBySavedQuery(input: {
savedQueryId: $id,
after: $from
first: $first
}) {
totalCount
nodes
{
id
globalKey
}
pageInfo {
endCursor
hasNextPage
}
}
}
"""
return PaginatedCollection(
client=self.client,
query=query_str,
params={'id': str(self.uid)},
dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'],
obj_class=lambda _, data_row_id_and_gk: CatalogSlice.
DataRowIdAndGlobalKey(data_row_id_and_gk.get('id'),
data_row_id_and_gk.get('globalKey', None)),
cursor_path=[
'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor'
])

@experimental
def export(self,
task_name: Optional[str] = None,
Expand Down