From 798f518fda75ea632c5003260d63800ea2aa1edb Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 11:34:07 +0100 Subject: [PATCH 01/17] Add Model Run exports --- labelbox/schema/filters.py | 20 +++++++++++++ labelbox/schema/model_run.py | 58 ++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 labelbox/schema/filters.py diff --git a/labelbox/schema/filters.py b/labelbox/schema/filters.py new file mode 100644 index 000000000..d809ab096 --- /dev/null +++ b/labelbox/schema/filters.py @@ -0,0 +1,20 @@ +from typing import Optional, TypedDict + + +class DataRowFilter(TypedDict): + data_row_details: Optional[bool] + media_attributes: Optional[bool] + metadata_fields: Optional[bool] + attachments: Optional[bool] + global_issues: Optional[bool] + + +class ProjectExportFilter(DataRowFilter): + project_details: Optional[bool] + label_details: Optional[bool] + performance_details: Optional[bool] + + +class ModelRunExportFilter(DataRowFilter): + # TODO: Add model run fields + pass diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 87b0dc557..03aba9602 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -446,6 +446,64 @@ def export_labels( self.uid) time.sleep(sleep_time) + """ + Creates a model run export task with the given filter and returns the task. + + >>> export_task = export_labels_v2("my_export_task", filter={"media_attributes": True}) + + """ + + def export_labels_v2(self, task_name: str, + filter: Optional[ModelRunExportFilter]) -> str: + mutation_name = "exportDataRows" + create_task_query_str = """mutation exportDataRowsPyApi($input: ExportDataRowsInput!){ + %s(input: $input) {taskId} + """ % (mutation_name) + res = self.client.execute( + create_task_query_str, + { + "input": { + "taskName": task_name, + "filters": { + modelRunIds: [self.uid] + }, + "params": { + includeAttachments: + filter["attachments"] + if filter and "attachments" in filter else False, + includeMediaAttributes: + filter['media_attributes'] if filter and + 'media_attributes' in filter else False, + includeMetadata: + filter['metadata'] + if filter and 'metadata' in filter else False, + # Arguments locked based on exectuion context + includeModelRuns: + True, + includeProjectDetails: + False, + includeLabels: + False, + includePerformanceDetails: + False, + }, + } + }, + ) + task_id = res[mutation_name] + task_id = res["taskId"] + user: User = self.client.get_user() + tasks: List[Task] = list( + user.created_tasks(where=Entity.Task.uid == task_id)) + # Cache user in a private variable as the relationship can't be + # resolved due to server-side limitations (see Task.created_by) + # for more info. + if len(tasks) != 1: + raise ResourceNotFoundError(Entity.Task, task_id) + task: Task = tasks[0] + task._user = user + return task + class ModelRunDataRow(DbObject): label_id = Field.String("label_id") From 18277f11d9add86ae58bd9325ae3e19cf158313e Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 11:53:27 +0100 Subject: [PATCH 02/17] Fix 3.7 --- labelbox/schema/filters.py | 5 ++++- labelbox/schema/model_run.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/labelbox/schema/filters.py b/labelbox/schema/filters.py index d809ab096..1ad0ee2d1 100644 --- a/labelbox/schema/filters.py +++ b/labelbox/schema/filters.py @@ -1,4 +1,7 @@ -from typing import Optional, TypedDict +try: + from typing import Optional, TypedDict +except: + from typing_extensions import Optional, TypedDict class DataRowFilter(TypedDict): diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 03aba9602..aa5bdefbe 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -477,6 +477,9 @@ def export_labels_v2(self, task_name: str, includeMetadata: filter['metadata'] if filter and 'metadata' in filter else False, + globalIssues: + filter["global_issues"] + if filter and 'global_issues' in filter else False, # Arguments locked based on exectuion context includeModelRuns: True, From 2d893269d333f9ef95479bdc56d57ac9d5bf4cae Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 12:00:14 +0100 Subject: [PATCH 03/17] Fix 3.7 v2 --- labelbox/schema/filters.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/labelbox/schema/filters.py b/labelbox/schema/filters.py index 1ad0ee2d1..be53b5ae3 100644 --- a/labelbox/schema/filters.py +++ b/labelbox/schema/filters.py @@ -1,7 +1,12 @@ try: - from typing import Optional, TypedDict + from typing import Optional except: - from typing_extensions import Optional, TypedDict + from typing_extensions import Optional + +try: + from typing import TypedDict +except: + from typing_extensions import TypedDict class DataRowFilter(TypedDict): From 3fa9a9ec25fd20ea6be1f53fd1afff232d296c2c Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 12:06:33 +0100 Subject: [PATCH 04/17] Fix 3.7 v3 --- labelbox/schema/filters.py | 13 +++++-------- labelbox/schema/model_run.py | 1 + 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/labelbox/schema/filters.py b/labelbox/schema/filters.py index be53b5ae3..bac940e0a 100644 --- a/labelbox/schema/filters.py +++ b/labelbox/schema/filters.py @@ -1,12 +1,9 @@ -try: - from typing import Optional -except: - from typing_extensions import Optional +import sys -try: - from typing import TypedDict -except: - from typing_extensions import TypedDict +if sys.version_info >= (3, 8): + from typing import TypedDict, Optional +else: + from typing_extensions import TypedDict, Optional class DataRowFilter(TypedDict): diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index aa5bdefbe..385dfa559 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -12,6 +12,7 @@ from labelbox.orm.query import results_query_part from labelbox.orm.model import Field, Relationship, Entity from labelbox.orm.db_object import DbObject, experimental +from labelbox.schema.filters import ModelRunExportFilter if TYPE_CHECKING: from labelbox import MEAPredictionImport From b26f235cff9e4cd6d141cb1d176901cd7ee9073c Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 12:25:14 +0100 Subject: [PATCH 05/17] Add test --- tests/integration/annotation_import/test_model_run.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/integration/annotation_import/test_model_run.py b/tests/integration/annotation_import/test_model_run.py index d7c3f925c..697b22c13 100644 --- a/tests/integration/annotation_import/test_model_run.py +++ b/tests/integration/annotation_import/test_model_run.py @@ -114,6 +114,16 @@ def test_model_run_export_labels(model_run_with_model_run_data_rows): assert len(labels) == 3 +def test_model_run_export_labels_v2(model_run_with_model_run_data_rows): + task_name = "test_task" + task = model_run_with_model_run_data_rows.export_labels_v2( + task_name, filter={"media_attributes: true"}) + assert task.name == task_name + task.wait_until_done() + assert task.status == "COMPLETED" + print(task) + + @pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", reason="does not work for onprem") def test_model_run_status(model_run_with_model_run_data_rows): From d5b588a6e854d5dd53f7ee70fcde0ce20287fafe Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 12:34:30 +0100 Subject: [PATCH 06/17] Fix 3.7 v4 --- labelbox/schema/filters.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/labelbox/schema/filters.py b/labelbox/schema/filters.py index bac940e0a..120d24854 100644 --- a/labelbox/schema/filters.py +++ b/labelbox/schema/filters.py @@ -1,9 +1,10 @@ import sys +from typing import Optional if sys.version_info >= (3, 8): - from typing import TypedDict, Optional + from typing import TypedDict else: - from typing_extensions import TypedDict, Optional + from typing_extensions import TypedDict class DataRowFilter(TypedDict): From 304d8e9f919720e4bdcf065d6dbc3a71776dd5b8 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 12:53:45 +0100 Subject: [PATCH 07/17] Fix return type --- labelbox/schema/model_run.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 385dfa559..0d37e3520 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -13,6 +13,8 @@ from labelbox.orm.model import Field, Relationship, Entity from labelbox.orm.db_object import DbObject, experimental from labelbox.schema.filters import ModelRunExportFilter +from labelbox.schema.task import Task +from labelbox.schema.user import User # type: ignore if TYPE_CHECKING: from labelbox import MEAPredictionImport @@ -455,7 +457,7 @@ def export_labels( """ def export_labels_v2(self, task_name: str, - filter: Optional[ModelRunExportFilter]) -> str: + filter: Optional[ModelRunExportFilter]) -> Task: mutation_name = "exportDataRows" create_task_query_str = """mutation exportDataRowsPyApi($input: ExportDataRowsInput!){ %s(input: $input) {taskId} From 3d805562f933d4cff839b7a997ce490f0f47c1dc Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 12:56:48 +0100 Subject: [PATCH 08/17] Fix metadata fields --- labelbox/schema/model_run.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 0d37e3520..431c387ab 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -470,7 +470,7 @@ def export_labels_v2(self, task_name: str, "filters": { modelRunIds: [self.uid] }, - "params": { + "params": { includeAttachments: filter["attachments"] if filter and "attachments" in filter else False, @@ -478,8 +478,8 @@ def export_labels_v2(self, task_name: str, filter['media_attributes'] if filter and 'media_attributes' in filter else False, includeMetadata: - filter['metadata'] - if filter and 'metadata' in filter else False, + filter['metadata_fields'] + if filter and 'metadata_fields' in filter else False, globalIssues: filter["global_issues"] if filter and 'global_issues' in filter else False, From 2f65e22ae59d3e3f10fb4680ab72d086747ab562 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 12:58:54 +0100 Subject: [PATCH 09/17] Formatting --- labelbox/schema/model_run.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 431c387ab..41d841224 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -470,7 +470,7 @@ def export_labels_v2(self, task_name: str, "filters": { modelRunIds: [self.uid] }, - "params": { + "params": { includeAttachments: filter["attachments"] if filter and "attachments" in filter else False, @@ -478,8 +478,8 @@ def export_labels_v2(self, task_name: str, filter['media_attributes'] if filter and 'media_attributes' in filter else False, includeMetadata: - filter['metadata_fields'] - if filter and 'metadata_fields' in filter else False, + filter['metadata_fields'] if filter and + 'metadata_fields' in filter else False, globalIssues: filter["global_issues"] if filter and 'global_issues' in filter else False, From 4fd153e52fe6f275b2bda30255e837ce1a63b7cb Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 13:17:46 +0100 Subject: [PATCH 10/17] Fix task_id ref --- labelbox/schema/model_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 41d841224..22a30a95c 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -496,7 +496,7 @@ def export_labels_v2(self, task_name: str, } }, ) - task_id = res[mutation_name] + res = res[mutation_name] task_id = res["taskId"] user: User = self.client.get_user() tasks: List[Task] = list( From 821097867091ed2445cce6295b53dd7b8f6c3339 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Mon, 23 Jan 2023 14:07:19 +0100 Subject: [PATCH 11/17] Fix payload --- labelbox/schema/model_run.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 22a30a95c..375ba2f9b 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -468,29 +468,29 @@ def export_labels_v2(self, task_name: str, "input": { "taskName": task_name, "filters": { - modelRunIds: [self.uid] + "modelRunIds": [self.uid] }, "params": { - includeAttachments: + "includeAttachments": filter["attachments"] if filter and "attachments" in filter else False, - includeMediaAttributes: + "includeMediaAttributes": filter['media_attributes'] if filter and 'media_attributes' in filter else False, - includeMetadata: + "includeMetadata": filter['metadata_fields'] if filter and 'metadata_fields' in filter else False, - globalIssues: + "globalIssues": filter["global_issues"] if filter and 'global_issues' in filter else False, # Arguments locked based on exectuion context - includeModelRuns: + "includeModelRuns": True, - includeProjectDetails: + "includeProjectDetails": False, - includeLabels: + "includeLabels": False, - includePerformanceDetails: + "includePerformanceDetails": False, }, } From 598db8d27991e74dc5b0d111ff0c624a14711d5d Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Fri, 27 Jan 2023 12:02:15 +0100 Subject: [PATCH 12/17] Remove global filters --- labelbox/schema/filters.py | 1 - labelbox/schema/model_run.py | 63 +++++++++---------- .../annotation_import/test_model_run.py | 1 - 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/labelbox/schema/filters.py b/labelbox/schema/filters.py index 120d24854..a9b93cb8d 100644 --- a/labelbox/schema/filters.py +++ b/labelbox/schema/filters.py @@ -12,7 +12,6 @@ class DataRowFilter(TypedDict): media_attributes: Optional[bool] metadata_fields: Optional[bool] attachments: Optional[bool] - global_issues: Optional[bool] class ProjectExportFilter(DataRowFilter): diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 375ba2f9b..bdf6c4d05 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -460,41 +460,40 @@ def export_labels_v2(self, task_name: str, filter: Optional[ModelRunExportFilter]) -> Task: mutation_name = "exportDataRows" create_task_query_str = """mutation exportDataRowsPyApi($input: ExportDataRowsInput!){ - %s(input: $input) {taskId} + %s(input: $input) {taskId} } """ % (mutation_name) + params = { + "input": { + "taskName": task_name, + "filters": { + "modelRunIds": [self.uid], + "projectIds": [] + }, + "params": { + "includeAttachments": + filter["attachments"] + if filter and "attachments" in filter else False, + "includeMediaAttributes": + filter['media_attributes'] + if filter and 'media_attributes' in filter else False, + "includeMetadata": + filter['metadata_fields'] + if filter and 'metadata_fields' in filter else False, + # Arguments locked based on exectuion context + "includeModelRuns": + True, + "includeProjectDetails": + False, + "includeLabels": + False, + "includePerformanceDetails": + False, + }, + } + } res = self.client.execute( create_task_query_str, - { - "input": { - "taskName": task_name, - "filters": { - "modelRunIds": [self.uid] - }, - "params": { - "includeAttachments": - filter["attachments"] - if filter and "attachments" in filter else False, - "includeMediaAttributes": - filter['media_attributes'] if filter and - 'media_attributes' in filter else False, - "includeMetadata": - filter['metadata_fields'] if filter and - 'metadata_fields' in filter else False, - "globalIssues": - filter["global_issues"] - if filter and 'global_issues' in filter else False, - # Arguments locked based on exectuion context - "includeModelRuns": - True, - "includeProjectDetails": - False, - "includeLabels": - False, - "includePerformanceDetails": - False, - }, - } - }, + params, ) res = res[mutation_name] task_id = res["taskId"] diff --git a/tests/integration/annotation_import/test_model_run.py b/tests/integration/annotation_import/test_model_run.py index 697b22c13..e453260e2 100644 --- a/tests/integration/annotation_import/test_model_run.py +++ b/tests/integration/annotation_import/test_model_run.py @@ -121,7 +121,6 @@ def test_model_run_export_labels_v2(model_run_with_model_run_data_rows): assert task.name == task_name task.wait_until_done() assert task.status == "COMPLETED" - print(task) @pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", From 2347a2b20279c212002a64e2fb135ac5388613f6 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Fri, 27 Jan 2023 22:26:27 +0100 Subject: [PATCH 13/17] Fix test --- labelbox/schema/task.py | 3 +++ tests/integration/annotation_import/test_model_run.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/labelbox/schema/task.py b/labelbox/schema/task.py index bd35261b9..7ce4279e2 100644 --- a/labelbox/schema/task.py +++ b/labelbox/schema/task.py @@ -81,6 +81,9 @@ def wait_till_done(self, timeout_seconds=300) -> None: def errors(self) -> Optional[Dict[str, Any]]: """ Fetch the error associated with an import task. """ + # TODO: We should handle error messages for export v2 tasks in the future. + if self.name != 'JSON Import': + return None if self.status == "FAILED": result = self._fetch_remote_json() return result["error"] diff --git a/tests/integration/annotation_import/test_model_run.py b/tests/integration/annotation_import/test_model_run.py index e453260e2..cc4c80ed5 100644 --- a/tests/integration/annotation_import/test_model_run.py +++ b/tests/integration/annotation_import/test_model_run.py @@ -119,8 +119,8 @@ def test_model_run_export_labels_v2(model_run_with_model_run_data_rows): task = model_run_with_model_run_data_rows.export_labels_v2( task_name, filter={"media_attributes: true"}) assert task.name == task_name - task.wait_until_done() - assert task.status == "COMPLETED" + task.wait_till_done() + assert task.status == "COMPLETE" @pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", From dc7d985ade17cae1a895136c877aa9ffc1fe2293 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Tue, 31 Jan 2023 12:53:01 +0100 Subject: [PATCH 14/17] CR changes --- labelbox/schema/export_params.py | 25 +++++++++++++++++++ labelbox/schema/filters.py | 25 ------------------- labelbox/schema/model_run.py | 16 ++++++------ .../annotation_import/test_model_run.py | 6 +++-- 4 files changed, 37 insertions(+), 35 deletions(-) create mode 100644 labelbox/schema/export_params.py delete mode 100644 labelbox/schema/filters.py diff --git a/labelbox/schema/export_params.py b/labelbox/schema/export_params.py new file mode 100644 index 000000000..4c583e697 --- /dev/null +++ b/labelbox/schema/export_params.py @@ -0,0 +1,25 @@ +import sys + +from typing import Optional +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +class DataRowParams(TypedDict): + include_data_row_details: Optional[bool] + include_media_attributes: Optional[bool] + include_metadata_fields: Optional[bool] + include_attachments: Optional[bool] + + +class ProjectExportParams(DataRowParams): + include_project_details: Optional[bool] + include_label_details: Optional[bool] + include_performance_details: Optional[bool] + + +class ModelRunExportParams(DataRowParams): + # TODO: Add model run fields + pass diff --git a/labelbox/schema/filters.py b/labelbox/schema/filters.py deleted file mode 100644 index a9b93cb8d..000000000 --- a/labelbox/schema/filters.py +++ /dev/null @@ -1,25 +0,0 @@ -import sys - -from typing import Optional -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -class DataRowFilter(TypedDict): - data_row_details: Optional[bool] - media_attributes: Optional[bool] - metadata_fields: Optional[bool] - attachments: Optional[bool] - - -class ProjectExportFilter(DataRowFilter): - project_details: Optional[bool] - label_details: Optional[bool] - performance_details: Optional[bool] - - -class ModelRunExportFilter(DataRowFilter): - # TODO: Add model run fields - pass diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index bdf6c4d05..e9686b8ed 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -12,7 +12,7 @@ from labelbox.orm.query import results_query_part from labelbox.orm.model import Field, Relationship, Entity from labelbox.orm.db_object import DbObject, experimental -from labelbox.schema.filters import ModelRunExportFilter +from labelbox.schema.export_params import ModelRunExportParams from labelbox.schema.task import Task from labelbox.schema.user import User # type: ignore @@ -457,7 +457,8 @@ def export_labels( """ def export_labels_v2(self, task_name: str, - filter: Optional[ModelRunExportFilter]) -> Task: + params: Optional[ModelRunExportParams]) -> Task: + _params = params or {} mutation_name = "exportDataRows" create_task_query_str = """mutation exportDataRowsPyApi($input: ExportDataRowsInput!){ %s(input: $input) {taskId} } @@ -471,14 +472,13 @@ def export_labels_v2(self, task_name: str, }, "params": { "includeAttachments": - filter["attachments"] - if filter and "attachments" in filter else False, + _params.get('include_attachments', False), "includeMediaAttributes": - filter['media_attributes'] - if filter and 'media_attributes' in filter else False, + _params.get('include_media_attributes', False), "includeMetadata": - filter['metadata_fields'] - if filter and 'metadata_fields' in filter else False, + _params.get('include_metadata_fields', False), + "includeDataRowDetails": + _params.get('include_data_row_details', False), # Arguments locked based on exectuion context "includeModelRuns": True, diff --git a/tests/integration/annotation_import/test_model_run.py b/tests/integration/annotation_import/test_model_run.py index cc4c80ed5..20f493c7b 100644 --- a/tests/integration/annotation_import/test_model_run.py +++ b/tests/integration/annotation_import/test_model_run.py @@ -116,11 +116,13 @@ def test_model_run_export_labels(model_run_with_model_run_data_rows): def test_model_run_export_labels_v2(model_run_with_model_run_data_rows): task_name = "test_task" - task = model_run_with_model_run_data_rows.export_labels_v2( - task_name, filter={"media_attributes: true"}) + params = {"media_attributes": True} + task = model_run_with_model_run_data_rows.export_labels_v2(task_name, + params=params) assert task.name == task_name task.wait_till_done() assert task.status == "COMPLETE" + # TODO: Download result and check it @pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", From 32c30f274cfd2ad5efc745cfd5d8ce9db8a2e5af Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Thu, 2 Feb 2023 02:23:52 +0100 Subject: [PATCH 15/17] Update tests --- labelbox/schema/export_params.py | 26 +++++----- labelbox/schema/model_run.py | 19 ++++---- .../annotation_import/test_model_run.py | 48 +++++++++++++++++-- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/labelbox/schema/export_params.py b/labelbox/schema/export_params.py index 4c583e697..5fc9c1cb0 100644 --- a/labelbox/schema/export_params.py +++ b/labelbox/schema/export_params.py @@ -1,25 +1,23 @@ import sys from typing import Optional -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict +from pydantic import BaseModel -class DataRowParams(TypedDict): - include_data_row_details: Optional[bool] - include_media_attributes: Optional[bool] - include_metadata_fields: Optional[bool] - include_attachments: Optional[bool] +class DataRowParams(BaseModel): + include_data_row_details: Optional[bool] = None + include_media_attributes: Optional[bool] = None + include_metadata_fields: Optional[bool] = None + include_attachments: Optional[bool] = None -class ProjectExportParams(DataRowParams): - include_project_details: Optional[bool] - include_label_details: Optional[bool] - include_performance_details: Optional[bool] +class ProjectExportParams(BaseModel): + include_project_details: Optional[bool] = None + include_label_details: Optional[bool] = None + include_performance_details: Optional[bool] = None -class ModelRunExportParams(DataRowParams): + +class ModelRunExportParams(BaseModel): # TODO: Add model run fields pass diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index e9686b8ed..6d6748084 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -14,7 +14,7 @@ from labelbox.orm.db_object import DbObject, experimental from labelbox.schema.export_params import ModelRunExportParams from labelbox.schema.task import Task -from labelbox.schema.user import User # type: ignore +from labelbox.schema.user import User if TYPE_CHECKING: from labelbox import MEAPredictionImport @@ -450,25 +450,24 @@ def export_labels( time.sleep(sleep_time) """ - Creates a model run export task with the given filter and returns the task. + Creates a model run export task with the given params and returns the task. - >>> export_task = export_labels_v2("my_export_task", filter={"media_attributes": True}) + >>> export_task = export_v2("my_export_task", params={"media_attributes": True}) """ - def export_labels_v2(self, task_name: str, - params: Optional[ModelRunExportParams]) -> Task: + def export_v2(self, task_name: str, + params: Optional[ModelRunExportParams]) -> Task: _params = params or {} - mutation_name = "exportDataRows" - create_task_query_str = """mutation exportDataRowsPyApi($input: ExportDataRowsInput!){ + mutation_name = "exportDataRowsInModelRun" + create_task_query_str = """mutation exportDataRowsInModelRunPyApi($input: ExportDataRowsInModelRunInput!){ %s(input: $input) {taskId} } """ % (mutation_name) params = { "input": { "taskName": task_name, "filters": { - "modelRunIds": [self.uid], - "projectIds": [] + "modelRunId": self.uid }, "params": { "includeAttachments": @@ -480,8 +479,6 @@ def export_labels_v2(self, task_name: str, "includeDataRowDetails": _params.get('include_data_row_details', False), # Arguments locked based on exectuion context - "includeModelRuns": - True, "includeProjectDetails": False, "includeLabels": diff --git a/tests/integration/annotation_import/test_model_run.py b/tests/integration/annotation_import/test_model_run.py index 20f493c7b..275719372 100644 --- a/tests/integration/annotation_import/test_model_run.py +++ b/tests/integration/annotation_import/test_model_run.py @@ -1,8 +1,11 @@ +import json import time import os import pytest from collections import Counter + +import requests from labelbox import DataSplit, ModelRun @@ -114,15 +117,50 @@ def test_model_run_export_labels(model_run_with_model_run_data_rows): assert len(labels) == 3 -def test_model_run_export_labels_v2(model_run_with_model_run_data_rows): +def test_model_run_export_v2(model_run_with_model_run_data_rows, + configured_project): task_name = "test_task" - params = {"media_attributes": True} - task = model_run_with_model_run_data_rows.export_labels_v2(task_name, - params=params) + + media_attributes = True + params = {"media_attributes": media_attributes} + task = model_run_with_model_run_data_rows.export_v2(task_name, + params=params) assert task.name == task_name task.wait_till_done() assert task.status == "COMPLETE" - # TODO: Download result and check it + + def download_result(result_url): + response = requests.get(result_url) + response.raise_for_status() + data = [json.loads(line) for line in response.text.splitlines()] + return data + + task_results = download_result(task.result_url) + + label_ids = [label.uid for label in configured_project.labels()] + label_ids_set = set(label_ids) + + assert len(task_results) == len(label_ids) + for task_result in task_results: + assert len(task_result['errors']) == 0 + # Check export param handling + if media_attributes: + assert 'media_attributes' in task_result and task_result[ + 'media_attributes'] is not None + else: + assert 'media_attributes' not in task_result or task_result[ + 'media_attributes'] is None + model_run = task_result['models'][ + model_run_with_model_run_data_rows.model_id]['model_runs'][ + model_run_with_model_run_data_rows.uid] + task_label_ids_set = set( + map(lambda label: label['id'], model_run['labels'])) + task_prediction_ids_set = set( + map(lambda prediction: prediction['id'], model_run['predictions'])) + for label_id in task_label_ids_set: + assert label_id in label_ids_set + for prediction_id in task_prediction_ids_set: + assert prediction_id in label_ids_set @pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", From a708c31148de24e10768cf40a3ad8fae81828062 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Thu, 2 Feb 2023 03:10:15 +0100 Subject: [PATCH 16/17] Rename params --- labelbox/schema/export_params.py | 8 ++++---- labelbox/schema/model_run.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/labelbox/schema/export_params.py b/labelbox/schema/export_params.py index 5fc9c1cb0..459e84b73 100644 --- a/labelbox/schema/export_params.py +++ b/labelbox/schema/export_params.py @@ -6,10 +6,10 @@ class DataRowParams(BaseModel): - include_data_row_details: Optional[bool] = None - include_media_attributes: Optional[bool] = None - include_metadata_fields: Optional[bool] = None - include_attachments: Optional[bool] = None + data_row_details: Optional[bool] = None + media_attributes: Optional[bool] = None + metadata_fields: Optional[bool] = None + attachments: Optional[bool] = None class ProjectExportParams(BaseModel): diff --git a/labelbox/schema/model_run.py b/labelbox/schema/model_run.py index 6d6748084..701fd38b7 100644 --- a/labelbox/schema/model_run.py +++ b/labelbox/schema/model_run.py @@ -471,13 +471,13 @@ def export_v2(self, task_name: str, }, "params": { "includeAttachments": - _params.get('include_attachments', False), + _params.get('attachments', False), "includeMediaAttributes": - _params.get('include_media_attributes', False), + _params.get('media_attributes', False), "includeMetadata": - _params.get('include_metadata_fields', False), + _params.get('metadata_fields', False), "includeDataRowDetails": - _params.get('include_data_row_details', False), + _params.get('data_row_details', False), # Arguments locked based on exectuion context "includeProjectDetails": False, From c9d82653398dd8d8486b08551271a827ae4bb569 Mon Sep 17 00:00:00 2001 From: mnoszczak Date: Thu, 2 Feb 2023 04:01:17 +0100 Subject: [PATCH 17/17] Revert back to TypedDict --- labelbox/schema/export_params.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/labelbox/schema/export_params.py b/labelbox/schema/export_params.py index 459e84b73..234be78d5 100644 --- a/labelbox/schema/export_params.py +++ b/labelbox/schema/export_params.py @@ -1,23 +1,25 @@ import sys from typing import Optional +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict -from pydantic import BaseModel +class DataRowParams(TypedDict): + data_row_details: Optional[bool] + media_attributes: Optional[bool] + metadata_fields: Optional[bool] + attachments: Optional[bool] -class DataRowParams(BaseModel): - data_row_details: Optional[bool] = None - media_attributes: Optional[bool] = None - metadata_fields: Optional[bool] = None - attachments: Optional[bool] = None +class ProjectExportParams(DataRowParams): + project_details: Optional[bool] + label_details: Optional[bool] + performance_details: Optional[bool] -class ProjectExportParams(BaseModel): - include_project_details: Optional[bool] = None - include_label_details: Optional[bool] = None - include_performance_details: Optional[bool] = None - -class ModelRunExportParams(BaseModel): +class ModelRunExportParams(DataRowParams): # TODO: Add model run fields pass