diff --git a/labelbox/schema/export_params.py b/labelbox/schema/export_params.py
new file mode 100644
index 000000000..234be78d5
--- /dev/null
+++ b/labelbox/schema/export_params.py
@@ -0,0 +1,25 @@
+import sys
+
+from typing import Optional
+if sys.version_info >= (3, 8):
+    from typing import TypedDict
+else:
+    from typing_extensions import TypedDict
+
+
+class DataRowParams(TypedDict):
+    data_row_details: Optional[bool]
+    media_attributes: Optional[bool]
+    metadata_fields: Optional[bool]
+    attachments: Optional[bool]
+
+
+class ProjectExportParams(DataRowParams):
+    project_details: Optional[bool]
+    label_details: Optional[bool]
+    performance_details: Optional[bool]
+
+
+class ModelRunExportParams(DataRowParams):
+    # TODO: Add model run fields
+    pass
diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py
index 87b0dc557..701fd38b7 100644
--- a/labelbox/schema/model_run.py
+++ b/labelbox/schema/model_run.py
@@ -12,6 +12,9 @@
 from labelbox.orm.query import results_query_part
 from labelbox.orm.model import Field, Relationship, Entity
 from labelbox.orm.db_object import DbObject, experimental
+from labelbox.schema.export_params import ModelRunExportParams
+from labelbox.schema.task import Task
+from labelbox.schema.user import User
 
 if TYPE_CHECKING:
     from labelbox import MEAPredictionImport
@@ -446,6 +449,63 @@ def export_labels(
                          self.uid)
             time.sleep(sleep_time)
 
+    """
+    Creates a model run export task with the given params and returns the task.
+
+    >>> export_task = export_v2("my_export_task", params={"media_attributes": True})
+
+    """
+
+    def export_v2(self, task_name: str,
+                  params: Optional[ModelRunExportParams] = None) -> Task:
+        _params = params or {}
+        mutation_name = "exportDataRowsInModelRun"
+        create_task_query_str = """mutation exportDataRowsInModelRunPyApi($input: ExportDataRowsInModelRunInput!){
+            %s(input: $input) {taskId} }
+            """ % (mutation_name)
+        params = {
+            "input": {
+                "taskName": task_name,
+                "filters": {
+                    "modelRunId": self.uid
+                },
+                "params": {
+                    "includeAttachments":
+                        _params.get('attachments', False),
+                    "includeMediaAttributes":
+                        _params.get('media_attributes', False),
+                    "includeMetadata":
+                        _params.get('metadata_fields', False),
+                    "includeDataRowDetails":
+                        _params.get('data_row_details', False),
+                    # Arguments locked based on execution context
+                    "includeProjectDetails":
+                        False,
+                    "includeLabels":
+                        False,
+                    "includePerformanceDetails":
+                        False,
+                },
+            }
+        }
+        res = self.client.execute(
+            create_task_query_str,
+            params,
+        )
+        res = res[mutation_name]
+        task_id = res["taskId"]
+        user: User = self.client.get_user()
+        tasks: List[Task] = list(
+            user.created_tasks(where=Entity.Task.uid == task_id))
+        # Cache user in a private variable as the relationship can't be
+        # resolved due to server-side limitations (see Task.created_by)
+        # for more info.
+        if len(tasks) != 1:
+            raise ResourceNotFoundError(Entity.Task, task_id)
+        task: Task = tasks[0]
+        task._user = user
+        return task
+
 
 class ModelRunDataRow(DbObject):
     label_id = Field.String("label_id")
diff --git a/labelbox/schema/task.py b/labelbox/schema/task.py
index bd35261b9..7ce4279e2 100644
--- a/labelbox/schema/task.py
+++ b/labelbox/schema/task.py
@@ -81,6 +81,9 @@ def wait_till_done(self, timeout_seconds=300) -> None:
 
     def errors(self) -> Optional[Dict[str, Any]]:
         """ Fetch the error associated with an import task. """
+        # TODO: We should handle error messages for export v2 tasks in the future.
+        if self.name != 'JSON Import':
+            return None
         if self.status == "FAILED":
             result = self._fetch_remote_json()
             return result["error"]
diff --git a/tests/integration/annotation_import/test_model_run.py b/tests/integration/annotation_import/test_model_run.py
index d7c3f925c..275719372 100644
--- a/tests/integration/annotation_import/test_model_run.py
+++ b/tests/integration/annotation_import/test_model_run.py
@@ -1,8 +1,11 @@
+import json
 import time
 import os
 import pytest
 from collections import Counter
+
+import requests
 
 from labelbox import DataSplit, ModelRun
 
 
@@ -114,6 +117,52 @@ def test_model_run_export_labels(model_run_with_model_run_data_rows):
     assert len(labels) == 3
 
 
+def test_model_run_export_v2(model_run_with_model_run_data_rows,
+                             configured_project):
+    task_name = "test_task"
+
+    media_attributes = True
+    params = {"media_attributes": media_attributes}
+    task = model_run_with_model_run_data_rows.export_v2(task_name,
+                                                        params=params)
+    assert task.name == task_name
+    task.wait_till_done()
+    assert task.status == "COMPLETE"
+
+    def download_result(result_url):
+        response = requests.get(result_url)
+        response.raise_for_status()
+        data = [json.loads(line) for line in response.text.splitlines()]
+        return data
+
+    task_results = download_result(task.result_url)
+
+    label_ids = [label.uid for label in configured_project.labels()]
+    label_ids_set = set(label_ids)
+
+    assert len(task_results) == len(label_ids)
+    for task_result in task_results:
+        assert len(task_result['errors']) == 0
+        # Check export param handling
+        if media_attributes:
+            assert 'media_attributes' in task_result and task_result[
+                'media_attributes'] is not None
+        else:
+            assert 'media_attributes' not in task_result or task_result[
+                'media_attributes'] is None
+        model_run = task_result['models'][
+            model_run_with_model_run_data_rows.model_id]['model_runs'][
+                model_run_with_model_run_data_rows.uid]
+        task_label_ids_set = set(
+            map(lambda label: label['id'], model_run['labels']))
+        task_prediction_ids_set = set(
+            map(lambda prediction: prediction['id'], model_run['predictions']))
+        for label_id in task_label_ids_set:
+            assert label_id in label_ids_set
+        for prediction_id in task_prediction_ids_set:
+            assert prediction_id in label_ids_set
+
+
 @pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem",
                     reason="does not work for onprem")
 def test_model_run_status(model_run_with_model_run_data_rows):
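
For reference, a minimal end-to-end usage sketch of the new ModelRun.export_v2 method (not part of the patch). The client setup, API key, and model run ID below are placeholders, and the newline-delimited JSON handling mirrors the download_result helper in the integration test above:

    # Hypothetical usage sketch of ModelRun.export_v2; the API key and model
    # run ID are placeholders, not values taken from this patch.
    import json

    import requests
    import labelbox as lb

    client = lb.Client(api_key="<YOUR_API_KEY>")
    model_run = client.get_model_run("<MODEL_RUN_ID>")

    # Start the export with a ModelRunExportParams-style dict; omitted keys
    # default to False on the server side per the mutation params above.
    task = model_run.export_v2("my_export_task",
                               params={"media_attributes": True,
                                       "attachments": True})
    task.wait_till_done()
    assert task.status == "COMPLETE"

    # The finished task exposes result_url, which serves one JSON object per line.
    response = requests.get(task.result_url)
    response.raise_for_status()
    rows = [json.loads(line) for line in response.text.splitlines()]
    print(f"exported {len(rows)} data rows")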