From b29e1d4f086fc63eb4b37b7124ee5e075adfe536 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Thu, 25 Jan 2024 14:35:27 -0800 Subject: [PATCH] Must pass model run id in order for model slices to work correctly --- labelbox/schema/slice.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/labelbox/schema/slice.py b/labelbox/schema/slice.py index e7a538da7..505585eb8 100644 --- a/labelbox/schema/slice.py +++ b/labelbox/schema/slice.py @@ -239,9 +239,10 @@ class ModelSlice(Slice): @classmethod def query_str(cls): query_str = """ - query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $from: DataRowIdentifierCursorInput, $first: Int!) { + query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $modelRunId: ID, $from: DataRowIdentifierCursorInput, $first: Int!) { getDataRowIdentifiersBySavedModelQuery(input: { savedQueryId: $id, + modelRunId: $modelRunId, after: $from first: $first }) { @@ -263,17 +264,23 @@ def query_str(cls): """ return query_str - def get_data_row_ids(self) -> PaginatedCollection: + def get_data_row_ids(self, model_run_id: str) -> PaginatedCollection: """ Fetches all data row ids that match this Slice + Params + model_run_id: str, required, uid or cuid of model run + Returns: A PaginatedCollection of data row ids """ return PaginatedCollection( client=self.client, query=ModelSlice.query_str(), - params={'id': str(self.uid)}, + params={ + 'id': str(self.uid), + 'modelRunId': model_run_id + }, dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id' ), @@ -282,17 +289,24 @@ def get_data_row_ids(self) -> PaginatedCollection: 'endCursor' ]) - def get_data_row_identifiers(self) -> PaginatedCollection: + def get_data_row_identifiers(self, + model_run_id: str) -> PaginatedCollection: """ Fetches all data row ids and global keys (where defined) that match this Slice + Params: + model_run_id : str, required, uid or cuid of model run + Returns: A PaginatedCollection of Slice.DataRowIdAndGlobalKey """ return PaginatedCollection( client=self.client, query=ModelSlice.query_str(), - params={'id': str(self.uid)}, + params={ + 'id': str(self.uid), + 'modelRunId': model_run_id + }, dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( data_row_id_and_gk.get('id'),