From 1ea2ad3bc257781522a3368816956d95ef7f6ce9 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Sun, 19 Mar 2023 02:30:26 +0100 Subject: [PATCH 1/7] Add catalog slice SDK support --- labelbox/schema/export_params.py | 11 +++- labelbox/schema/slice.py | 86 +++++++++++++++++++++++++++++++- 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/labelbox/schema/export_params.py b/labelbox/schema/export_params.py index fa54ef81d..4fcde5c2a 100644 --- a/labelbox/schema/export_params.py +++ b/labelbox/schema/export_params.py @@ -1,6 +1,6 @@ import sys -from typing import Optional +from typing import Optional, List from labelbox.schema.media_type import MediaType if sys.version_info >= (3, 8): @@ -22,6 +22,15 @@ class ProjectExportParams(DataRowParams): performance_details: Optional[bool] +class CatalogSliceExportParams(DataRowParams): + project_details: Optional[bool] + label_details: Optional[bool] + performance_details: Optional[bool] + model_runs_ids: Optional[List[str]] + projects_ids: Optional[List[str]] + pass + + class ModelRunExportParams(DataRowParams): # TODO: Add model run fields pass diff --git a/labelbox/schema/slice.py b/labelbox/schema/slice.py index ab70f4b1b..1010f9af3 100644 --- a/labelbox/schema/slice.py +++ b/labelbox/schema/slice.py @@ -1,6 +1,11 @@ +from typing import Optional, List +from labelbox.exceptions import ResourceNotFoundError from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field +from labelbox.orm.model import Entity, Field from labelbox.pagination import PaginatedCollection +from labelbox.schema.export_params import CatalogSliceExportParams +from labelbox.schema.task import Task +from labelbox.schema.user import User class Slice(DbObject): @@ -59,6 +64,85 @@ def get_data_row_ids(self) -> PaginatedCollection: obj_class=lambda _, data_row_id: data_row_id, cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor']) + def export_v2(self, + task_name: Optional[str] = None, + params: Optional[CatalogSliceExportParams] = None) -> Task: + """ + Creates a slice export task with the given params and returns the task. + >>> slice = client.get_catalog_slice("SLICE_ID") + >>> task = slice.export_v2( + >>> params={"performance_details": False, "label_details": True} + >>> ) + >>> task.wait_till_done() + >>> task.result + """ + + _params = params or CatalogSliceExportParams({ + "attachments": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_runs_ids": None, + "projects_ids": None, + }) + + mutation_name = "exportDataRowsInSlice" + create_task_query_str = """mutation exportDataRowsInSlicePyApi($input: ExportDataRowsInSliceInput!){ + %s(input: $input) {taskId} } + """ % (mutation_name) + + media_type_override = _params.get('media_type_override', None) + query_params = { + "input": { + "taskName": task_name, + "filters": { + "sliceId": self.uid + }, + "params": { + "mediaTypeOverride": + media_type_override.value + if media_type_override is not None else None, + "includeAttachments": + _params.get('attachments', False), + "includeMetadata": + _params.get('metadata_fields', False), + "includeDataRowDetails": + _params.get('data_row_details', False), + "includeProjectDetails": + _params.get('project_details', False), + "includePerformanceDetails": + _params.get('performance_details', False), + "includeLabelDetails": + _params.get('label_details', False), + "projectIds": + _params.get('projects_ids', None), + "modelRunIds": + _params.get('model_runs_ids', None), + }, + } + } + + res = self.client.execute( + create_task_query_str, + query_params, + ) + res = res[mutation_name] + task_id = res["taskId"] + user: User = self.client.get_user() + tasks: List[Task] = list( + user.created_tasks(where=Entity.Task.uid == task_id)) + # Cache user in a private variable as the relationship can't be + # resolved due to server-side limitations (see Task.created_by) + # for more info. + if len(tasks) != 1: + raise ResourceNotFoundError(Entity.Task, task_id) + task: Task = tasks[0] + task._user = user + return task + class ModelSlice(Slice): """ From 784a12190459bd50dd0b8a75929c60ebfad69da7 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Sun, 19 Mar 2023 03:08:16 +0100 Subject: [PATCH 2/7] Add export Dataset.export_v2 method, improve docs --- labelbox/schema/dataset.py | 195 +++++++++++++++++++++++++++++- labelbox/schema/export_filters.py | 10 +- labelbox/schema/export_params.py | 2 +- labelbox/schema/project.py | 6 +- labelbox/schema/slice.py | 6 +- 5 files changed, 207 insertions(+), 12 deletions(-) diff --git a/labelbox/schema/dataset.py b/labelbox/schema/dataset.py index bbabd6259..0ce9211d4 100644 --- a/labelbox/schema/dataset.py +++ b/labelbox/schema/dataset.py @@ -1,4 +1,4 @@ -from typing import Generator, List, Union, Any, TYPE_CHECKING +from typing import Collection, Dict, Generator, List, Optional, Union, Any, TYPE_CHECKING import os import json import logging @@ -17,9 +17,12 @@ from labelbox.orm.model import Entity, Field, Relationship from labelbox.orm import query from labelbox.exceptions import MalformedQueryException - -if TYPE_CHECKING: - from labelbox import Task, User, DataRow +from labelbox.schema.data_row import DataRow +from labelbox.schema.export_filters import DatasetExportFilters, SharedExportFilters +from labelbox.schema.export_params import CatalogExportParams +from labelbox.schema.project import _validate_datetime +from labelbox.schema.task import Task +from labelbox.schema.user import User logger = logging.getLogger(__name__) @@ -534,3 +537,187 @@ def export_data_rows(self, logger.debug("Dataset '%s' data row export, waiting for server...", self.uid) time.sleep(sleep_time) + + def export_v2(self, + task_name: Optional[str] = None, + filters: Optional[DatasetExportFilters] = None, + params: Optional[CatalogExportParams] = None) -> Task: + """ + Creates a dataset export task with the given params and returns the task. + + >>> dataset = client.get_dataset(DATASET_ID) + >>> task = dataset.export_v2( + >>> filters={ + >>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], + >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"] + >>> }, + >>> params={ + >>> "performance_details": False, + >>> "label_details": True + >>> }) + >>> task.wait_till_done() + >>> task.result + """ + + _params = params or CatalogExportParams({ + "attachments": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_runs_ids": None, + "project_ids": None, + }) + + _filters = filters or DatasetExportFilters({ + "last_activity_at": None, + "label_created_at": None + }) + + def _get_timezone() -> str: + timezone_query_str = """query CurrentUserPyApi { user { timezone } }""" + tz_res = self.client.execute(timezone_query_str) + return tz_res["user"]["timezone"] or "UTC" + + timezone: Optional[str] = None + + mutation_name = "exportDataRowsInCatalog" + create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){ + %s(input: $input) {taskId} } + """ % (mutation_name) + + search_query: List[Dict[str, Collection[str]]] = [] + search_query.append({ + "ids": [self.uid], + "operator": "is", + "type": "dataset" + }) + media_type_override = _params.get('media_type_override', None) + + if task_name is None: + task_name = f"Export v2: dataset - {self.name}" + query_params = { + "input": { + "taskName": task_name, + "filters": { + "searchQuery": { + "scope": None, + "query": search_query + } + }, + "params": { + "mediaTypeOverride": + media_type_override.value + if media_type_override is not None else None, + "includeAttachments": + _params.get('attachments', False), + "includeMetadata": + _params.get('metadata_fields', False), + "includeDataRowDetails": + _params.get('data_row_details', False), + "includeProjectDetails": + _params.get('project_details', False), + "includePerformanceDetails": + _params.get('performance_details', False), + "includeLabelDetails": + _params.get('label_details', False) + }, + } + } + + if "last_activity_at" in _filters and _filters[ + 'last_activity_at'] is not None: + if timezone is None: + timezone = _get_timezone() + values = _filters['last_activity_at'] + start, end = values + if (start is not None and end is not None): + [_validate_datetime(date) for date in values] + search_query.append({ + "type": "data_row_last_activity_at", + "value": { + "operator": "BETWEEN", + "timezone": timezone, + "value": { + "min": start, + "max": end + } + } + }) + elif (start is not None): + _validate_datetime(start) + search_query.append({ + "type": "data_row_last_activity_at", + "value": { + "operator": "GREATER_THAN_OR_EQUAL", + "timezone": timezone, + "value": start + } + }) + elif (end is not None): + _validate_datetime(end) + search_query.append({ + "type": "data_row_last_activity_at", + "value": { + "operator": "LESS_THAN_OR_EQUAL", + "timezone": timezone, + "value": end + } + }) + + if "label_created_at" in _filters and _filters[ + "label_created_at"] is not None: + if timezone is None: + timezone = _get_timezone() + values = _filters['label_created_at'] + start, end = values + if (start is not None and end is not None): + [_validate_datetime(date) for date in values] + search_query.append({ + "type": "labeled_at", + "value": { + "operator": "BETWEEN", + "value": { + "min": start, + "max": end + } + } + }) + elif (start is not None): + _validate_datetime(start) + search_query.append({ + "type": "labeled_at", + "value": { + "operator": "GREATER_THAN_OR_EQUAL", + "value": start + } + }) + elif (end is not None): + _validate_datetime(end) + search_query.append({ + "type": "labeled_at", + "value": { + "operator": "LESS_THAN_OR_EQUAL", + "value": end + } + }) + + res = self.client.execute( + create_task_query_str, + query_params, + ) + res = res[mutation_name] + task_id = res["taskId"] + user: User = self.client.get_user() + tasks: List[Task] = list( + user.created_tasks(where=Entity.Task.uid == task_id)) + # Cache user in a private variable as the relationship can't be + # resolved due to server-side limitations (see Task.created_by) + # for more info. + if len(tasks) != 1: + raise ResourceNotFoundError(Entity.Task, task_id) + task: Task = tasks[0] + task._user = user + return task diff --git a/labelbox/schema/export_filters.py b/labelbox/schema/export_filters.py index 928b3dcba..402f083e4 100644 --- a/labelbox/schema/export_filters.py +++ b/labelbox/schema/export_filters.py @@ -9,7 +9,7 @@ from typing import Tuple -class ProjectExportFilters(TypedDict): +class SharedExportFilters(TypedDict): label_created_at: Optional[Tuple[str, str]] """ Date range for labels created at Formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" @@ -26,3 +26,11 @@ class ProjectExportFilters(TypedDict): >>> [None, "2050-01-01 00:00:00"] >>> ["2000-01-01 00:00:00", None] """ + + +class ProjectExportFilters(SharedExportFilters): + pass + + +class DatasetExportFilters(SharedExportFilters): + pass diff --git a/labelbox/schema/export_params.py b/labelbox/schema/export_params.py index 4fcde5c2a..de4e51eb6 100644 --- a/labelbox/schema/export_params.py +++ b/labelbox/schema/export_params.py @@ -22,7 +22,7 @@ class ProjectExportParams(DataRowParams): performance_details: Optional[bool] -class CatalogSliceExportParams(DataRowParams): +class CatalogExportParams(DataRowParams): project_details: Optional[bool] label_details: Optional[bool] performance_details: Optional[bool] diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py index 43065a90f..825589eec 100644 --- a/labelbox/schema/project.py +++ b/labelbox/schema/project.py @@ -420,7 +420,7 @@ def export_v2(self, filters: Optional[ProjectExportFilters] = None, params: Optional[ProjectExportParams] = None) -> Task: """ - Creates a project run export task with the given params and returns the task. + Creates a project export task with the given params and returns the task. For more information visit: https://docs.labelbox.com/docs/exports-v2#export-from-a-project-python-sdk @@ -430,8 +430,8 @@ def export_v2(self, >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"] >>> }, >>> params={ - >>> "include_performance_details": False, - >>> "include_labels": True + >>> "performance_details": False, + >>> "label_details": True >>> }) >>> task.wait_till_done() >>> task.result diff --git a/labelbox/schema/slice.py b/labelbox/schema/slice.py index 1010f9af3..67f947def 100644 --- a/labelbox/schema/slice.py +++ b/labelbox/schema/slice.py @@ -3,7 +3,7 @@ from labelbox.orm.db_object import DbObject from labelbox.orm.model import Entity, Field from labelbox.pagination import PaginatedCollection -from labelbox.schema.export_params import CatalogSliceExportParams +from labelbox.schema.export_params import CatalogExportParams from labelbox.schema.task import Task from labelbox.schema.user import User @@ -66,7 +66,7 @@ def get_data_row_ids(self) -> PaginatedCollection: def export_v2(self, task_name: Optional[str] = None, - params: Optional[CatalogSliceExportParams] = None) -> Task: + params: Optional[CatalogExportParams] = None) -> Task: """ Creates a slice export task with the given params and returns the task. >>> slice = client.get_catalog_slice("SLICE_ID") @@ -77,7 +77,7 @@ def export_v2(self, >>> task.result """ - _params = params or CatalogSliceExportParams({ + _params = params or CatalogExportParams({ "attachments": False, "metadata_fields": False, "data_row_details": False, From 4f33e9631a9722c18e6e430dbb9e582f2b3e7d5a Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Sun, 19 Mar 2023 03:14:20 +0100 Subject: [PATCH 3/7] Add export dataset tests --- tests/integration/test_dataset.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/integration/test_dataset.py b/tests/integration/test_dataset.py index 002cbdce3..147ecc947 100644 --- a/tests/integration/test_dataset.py +++ b/tests/integration/test_dataset.py @@ -148,6 +148,20 @@ def test_data_row_export(dataset, image_url): assert set(result) == ids +def test_dataset_export_v2(dataset, image_url): + n_data_rows = 5 + ids = set() + for _ in range(n_data_rows): + ids.add(dataset.create_data_row(row_data=image_url)) + task = dataset.export_v2(params={ + "performance_details": False, + "label_details": True + }) + assert task.status == "COMPLETE" + assert task.errors is None + assert len(task.result) == n_data_rows + + def test_create_descriptor_file(dataset): import unittest.mock as mock with mock.patch.object(dataset.client, From cdea2747069ea6ff2b0079577130719b5f1c9ad5 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Sun, 19 Mar 2023 03:40:47 +0100 Subject: [PATCH 4/7] Add static method to export data rows + add tests --- labelbox/schema/data_row.py | 113 +++++++++++++++++++++++++++- labelbox/schema/dataset.py | 2 +- tests/integration/test_data_rows.py | 11 +++ tests/integration/test_dataset.py | 4 + 4 files changed, 127 insertions(+), 3 deletions(-) diff --git a/labelbox/schema/data_row.py b/labelbox/schema/data_row.py index 731f2eb98..9bb1ed69b 100644 --- a/labelbox/schema/data_row.py +++ b/labelbox/schema/data_row.py @@ -1,14 +1,18 @@ import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Collection, Dict, List, Optional import json +from labelbox.exceptions import ResourceNotFoundError from labelbox.orm import query from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable from labelbox.orm.model import Entity, Field, Relationship from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore +from labelbox.schema.export_params import CatalogExportParams +from labelbox.schema.task import Task +from labelbox.schema.user import User # type: ignore if TYPE_CHECKING: - from labelbox import AssetAttachment + from labelbox import AssetAttachment, Client logger = logging.getLogger(__name__) @@ -150,3 +154,108 @@ def create_attachment(self, }) return Entity.AssetAttachment(self.client, res["createDataRowAttachment"]) + + @staticmethod + def export_v2(client: 'Client', + data_rows: List['DataRow'], + task_name: Optional[str] = None, + params: Optional[CatalogExportParams] = None) -> Task: + """ + Creates a data rows export task with the given list, params and returns the task. + + >>> dataset = client.get_dataset(DATASET_ID) + >>> task = DataRow.export_v2( + >>> data_rows_ids=[data_row.uid for data_row in dataset.data_rows.list()], + >>> filters={ + >>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], + >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"] + >>> }, + >>> params={ + >>> "performance_details": False, + >>> "label_details": True + >>> }) + >>> task.wait_till_done() + >>> task.result + """ + print('export start') + + _params = params or CatalogExportParams({ + "attachments": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_runs_ids": None, + "projects_ids": None, + }) + + mutation_name = "exportDataRowsInCatalog" + create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){ + %s(input: $input) {taskId} } + """ % (mutation_name) + + data_rows_ids = [data_row.uid for data_row in data_rows] + search_query: List[Dict[str, Collection[str]]] = [] + search_query.append({ + "ids": data_rows_ids, + "operator": "is", + "type": "data_row_id" + }) + + print(search_query) + media_type_override = _params.get('media_type_override', None) + + if task_name is None: + task_name = f"Export v2: data rows (%s)" % len(data_rows_ids) + query_params = { + "input": { + "taskName": task_name, + "filters": { + "searchQuery": { + "scope": None, + "query": search_query + } + }, + "params": { + "mediaTypeOverride": + media_type_override.value + if media_type_override is not None else None, + "includeAttachments": + _params.get('attachments', False), + "includeMetadata": + _params.get('metadata_fields', False), + "includeDataRowDetails": + _params.get('data_row_details', False), + "includeProjectDetails": + _params.get('project_details', False), + "includePerformanceDetails": + _params.get('performance_details', False), + "includeLabelDetails": + _params.get('label_details', False) + }, + } + } + + print('export execution') + print(client) + + res = client.execute( + create_task_query_str, + query_params, + ) + print(res) + res = res[mutation_name] + task_id = res["taskId"] + user: User = client.get_user() + tasks: List[Task] = list( + user.created_tasks(where=Entity.Task.uid == task_id)) + # Cache user in a private variable as the relationship can't be + # resolved due to server-side limitations (see Task.created_by) + # for more info. + if len(tasks) != 1: + raise ResourceNotFoundError(Entity.Task, task_id) + task: Task = tasks[0] + task._user = user + return task diff --git a/labelbox/schema/dataset.py b/labelbox/schema/dataset.py index 0ce9211d4..265792613 100644 --- a/labelbox/schema/dataset.py +++ b/labelbox/schema/dataset.py @@ -568,7 +568,7 @@ def export_v2(self, "label_details": False, "media_type_override": None, "model_runs_ids": None, - "project_ids": None, + "projects_ids": None, }) _filters = filters or DatasetExportFilters({ diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py index 677348934..2e6ee6203 100644 --- a/tests/integration/test_data_rows.py +++ b/tests/integration/test_data_rows.py @@ -1,4 +1,5 @@ from tempfile import NamedTemporaryFile +import time import uuid from datetime import datetime import json @@ -962,3 +963,13 @@ def test_create_data_row_with_media_type(dataset, image_url): assert "Found invalid contents for media type: \'IMAGE\'" in str(exc.value) dataset.create_data_row(row_data=image_url, media_type="IMAGE") + + +def test_export_data_rows(client, datarow): + # Ensure created data rows are indexed + time.sleep(10) + task = DataRow.export_v2(client=client, data_rows=[datarow]) + task.wait_till_done() + assert task.status == "COMPLETE" + assert task.errors is None + assert len(task.result) == 1 diff --git a/tests/integration/test_dataset.py b/tests/integration/test_dataset.py index 147ecc947..1e7df6abb 100644 --- a/tests/integration/test_dataset.py +++ b/tests/integration/test_dataset.py @@ -1,4 +1,5 @@ import json +import time import pytest import requests from labelbox import Dataset @@ -153,10 +154,13 @@ def test_dataset_export_v2(dataset, image_url): ids = set() for _ in range(n_data_rows): ids.add(dataset.create_data_row(row_data=image_url)) + + time.sleep(10) task = dataset.export_v2(params={ "performance_details": False, "label_details": True }) + task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None assert len(task.result) == n_data_rows From 5a372fcfd1c4ffc92b7ac500ec9d2378411c8d2e Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 20 Mar 2023 12:36:25 +0100 Subject: [PATCH 5/7] Add test_slice, fix get slice --- labelbox/client.py | 3 +-- tests/integration/test_slice.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_slice.py diff --git a/labelbox/client.py b/labelbox/client.py index c7a1eb3c7..4f79a4b86 100644 --- a/labelbox/client.py +++ b/labelbox/client.py @@ -1559,8 +1559,7 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice: Returns: CatalogSlice """ - query_str = """ - query getSavedQueryPyApi($id: ID!) { + query_str = """query getSavedQueryPyApi($id: ID!) { getSavedQuery(id: $id) { id name diff --git a/tests/integration/test_slice.py b/tests/integration/test_slice.py new file mode 100644 index 000000000..609dc5aa6 --- /dev/null +++ b/tests/integration/test_slice.py @@ -0,0 +1,15 @@ +import pytest + + +def test_export_v2_slice(client): + # Since we don't have CRUD for slices, we'll just use the one that's already there + SLICE_ID = "clfgqf1c72mk107zx6ypo9bse" + slice = client.get_catalog_slice(SLICE_ID) + task = slice.export_v2(params={ + "performance_details": False, + "label_details": True + }) + task.wait_till_done() + assert task.status == "COMPLETE" + assert task.errors is None + assert len(task.result) != 0 From 8437b320dad25bc77edae1202f6e1ee4984f02ae Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 20 Mar 2023 13:47:34 +0100 Subject: [PATCH 6/7] Remove debug code --- labelbox/schema/data_row.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/labelbox/schema/data_row.py b/labelbox/schema/data_row.py index 9bb1ed69b..42a34a5bf 100644 --- a/labelbox/schema/data_row.py +++ b/labelbox/schema/data_row.py @@ -238,9 +238,6 @@ def export_v2(client: 'Client', } } - print('export execution') - print(client) - res = client.execute( create_task_query_str, query_params, From 750cf3d7b8250e91b8899a5b45dba6cf164387b6 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 20 Mar 2023 14:35:35 +0100 Subject: [PATCH 7/7] Skip create slice test --- tests/integration/test_slice.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/test_slice.py b/tests/integration/test_slice.py index 609dc5aa6..9b1727393 100644 --- a/tests/integration/test_slice.py +++ b/tests/integration/test_slice.py @@ -1,6 +1,8 @@ import pytest +@pytest.mark.skip( + 'Skipping until we have a way to create slices programatically') def test_export_v2_slice(client): # Since we don't have CRUD for slices, we'll just use the one that's already there SLICE_ID = "clfgqf1c72mk107zx6ypo9bse"