diff --git a/labelbox/schema/data_row.py b/labelbox/schema/data_row.py
index d8111bb9f..434ce5016 100644
--- a/labelbox/schema/data_row.py
+++ b/labelbox/schema/data_row.py
@@ -1,5 +1,5 @@
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
 import json
 
 from labelbox.exceptions import ResourceNotFoundError
@@ -7,6 +7,7 @@
 from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.schema.data_row_metadata import DataRowMetadataField  # type: ignore
+from labelbox.schema.export_filters import DatarowExportFilters, build_filters
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
 from labelbox.schema.task import Task
 from labelbox.schema.user import User  # type: ignore
@@ -157,15 +158,21 @@ def create_attachment(self,
 
     @staticmethod
     def export_v2(client: 'Client',
-                  data_rows: List['DataRow'],
+                  data_rows: List[Union[str, 'DataRow']],
                   task_name: Optional[str] = None,
                   params: Optional[CatalogExportParams] = None) -> Task:
         """
         Creates a data rows export task with the given list, params and returns the task.
+
+        Args:
+            client (Client): client to use to make the export request
+            data_rows (list of DataRow or str): list of data row objects or data row ids to export
+            task_name (str): name of remote task
+            params (CatalogExportParams): export params
+
 
         >>> dataset = client.get_dataset(DATASET_ID)
         >>> task = DataRow.export_v2(
-        >>>     data_rows_ids=[data_row.uid for data_row in dataset.data_rows.list()],
+        >>>     data_rows=[data_row.uid for data_row in dataset.data_rows.list()],
         >>>     filters={
         >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
         >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
@@ -198,19 +205,26 @@ def export_v2(client: 'Client',
               %s(input: $input) {taskId}
             }
           """ % (mutation_name)
-        data_rows_ids = [data_row.uid for data_row in data_rows]
-        search_query: List[Dict[str, Collection[str]]] = []
-        search_query.append({
-            "ids": data_rows_ids,
-            "operator": "is",
-            "type": "data_row_id"
+        data_row_ids = []
+        if data_rows is not None:
+            for dr in data_rows:
+                if isinstance(dr, DataRow):
+                    data_row_ids.append(dr.uid)
+                elif isinstance(dr, str):
+                    data_row_ids.append(dr)
+
+        filters = DatarowExportFilters({
+            "last_activity_at": None,
+            "label_created_at": None,
+            "data_row_ids": data_row_ids,
         })
+        search_query: List[Dict[str, Collection[str]]] = []
+        search_query = build_filters(client, filters)
 
-        print(search_query)
         media_type_override = _params.get('media_type_override', None)
 
         if task_name is None:
-            task_name = f"Export v2: data rows (%s)" % len(data_rows_ids)
+            task_name = f"Export v2: data rows (%s)" % len(data_row_ids)
         query_params = {
             "input": {
                 "taskName": task_name,
diff --git a/labelbox/schema/export_filters.py b/labelbox/schema/export_filters.py
index dc4d4e214..7f007606c 100644
--- a/labelbox/schema/export_filters.py
+++ b/labelbox/schema/export_filters.py
@@ -48,6 +48,10 @@
 class DatasetExportFilters(SharedExportFilters):
     pass
 
 
+class DatarowExportFilters(SharedExportFilters):
+    pass
+
+
 def validate_datetime(datetime_str: str) -> bool:
     """helper function to validate that datetime's format: "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" or ISO 8061 format "YYYY-MM-DDThh:mm:ss±hhmm" (Example: "2023-05-23T14:30:00+0530")"""
diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py
index 36fc8c692..bb002b8e2 100644
--- a/tests/integration/test_data_rows.py
+++ b/tests/integration/test_data_rows.py
@@ -9,6 +9,7 @@
 from labelbox import DataRow
 from labelbox.exceptions import MalformedQueryException
+from labelbox.schema.export_filters import DatarowExportFilters
 from labelbox.schema.task import Task
 from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadataKind
 import labelbox.exceptions
 
@@ -1001,8 +1002,17 @@ def test_export_data_rows(client, data_row, wait_for_data_row_processing):
     # Ensure created data rows are indexed
     data_row = wait_for_data_row_processing(client, data_row)
     time.sleep(7)  # temp fix for ES indexing delay
+
     task = DataRow.export_v2(client=client, data_rows=[data_row])
     task.wait_till_done()
     assert task.status == "COMPLETE"
     assert task.errors is None
     assert len(task.result) == 1
+    assert task.result[0]['data_row']['id'] == data_row.uid
+
+    task = DataRow.export_v2(client=client, data_rows=[data_row.uid])
+    task.wait_till_done()
+    assert task.status == "COMPLETE"
+    assert task.errors is None
+    assert len(task.result) == 1
+    assert task.result[0]['data_row']['id'] == data_row.uid
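
For reference, a minimal usage sketch of the updated DataRow.export_v2 entry point, which (per the patch above) now accepts DataRow objects and data row id strings interchangeably. This is an illustration, not part of the patch: DATASET_ID is a placeholder, the dataset is assumed to contain at least two data rows, and the client is assumed to pick up LABELBOX_API_KEY from the environment.

    from labelbox import Client, DataRow

    client = Client()  # falls back to the LABELBOX_API_KEY environment variable
    dataset = client.get_dataset(DATASET_ID)  # DATASET_ID: placeholder id

    # The data_rows list may mix DataRow objects and plain id strings;
    # export_v2 normalizes both to ids before building the export filter.
    rows = list(dataset.data_rows.list())
    task = DataRow.export_v2(client=client,
                             data_rows=[rows[0], rows[1].uid],
                             task_name="example export")
    task.wait_till_done()
    assert task.status == "COMPLETE"
    print(task.result[0]['data_row']['id'])

Routing the id list through DatarowExportFilters and build_filters, instead of hand-assembling the search_query dict, keeps data row exports on the same filter-validation path as the project and dataset export_v2 variants.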