Merged
25 changes: 25 additions & 0 deletions labelbox/schema/export_params.py
@@ -0,0 +1,25 @@
import sys

from typing import Optional
if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict


class DataRowParams(TypedDict):
data_row_details: Optional[bool]
media_attributes: Optional[bool]
metadata_fields: Optional[bool]
attachments: Optional[bool]


class ProjectExportParams(DataRowParams):
project_details: Optional[bool]
label_details: Optional[bool]
performance_details: Optional[bool]


class ModelRunExportParams(DataRowParams):
# TODO: Add model run fields
pass
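
For reference, a minimal sketch of how a caller might fill in one of these TypedDicts (the variable name below is illustrative and not part of this diff). At runtime they are plain dicts, so export_v2 reads missing keys with .get() and falls back to False:

from labelbox.schema.export_params import ModelRunExportParams

# Illustrative only: a fully populated params dict for a model run export.
params: ModelRunExportParams = {
    "data_row_details": False,
    "media_attributes": True,
    "metadata_fields": False,
    "attachments": False,
}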
60 changes: 60 additions & 0 deletions labelbox/schema/model_run.py
@@ -12,6 +12,9 @@
from labelbox.orm.query import results_query_part
from labelbox.orm.model import Field, Relationship, Entity
from labelbox.orm.db_object import DbObject, experimental
from labelbox.schema.export_params import ModelRunExportParams
from labelbox.schema.task import Task
from labelbox.schema.user import User

if TYPE_CHECKING:
from labelbox import MEAPredictionImport
@@ -446,6 +449,63 @@ def export_labels(
self.uid)
time.sleep(sleep_time)

"""
Creates a model run export task with the given params and returns the task.

>>> export_task = export_v2("my_export_task", params={"media_attributes": True})

"""

def export_v2(self, task_name: str,
params: Optional[ModelRunExportParams]) -> Task:
_params = params or {}
mutation_name = "exportDataRowsInModelRun"
create_task_query_str = """mutation exportDataRowsInModelRunPyApi($input: ExportDataRowsInModelRunInput!){
%s(input: $input) {taskId} }
""" % (mutation_name)
params = {
"input": {
"taskName": task_name,
"filters": {
"modelRunId": self.uid
},
"params": {
"includeAttachments":
_params.get('attachments', False),
"includeMediaAttributes":
_params.get('media_attributes', False),
"includeMetadata":
_params.get('metadata_fields', False),
"includeDataRowDetails":
_params.get('data_row_details', False),
# Arguments locked based on execution context
"includeProjectDetails":
False,
"includeLabels":
False,
"includePerformanceDetails":
False,
},
}
}
res = self.client.execute(
create_task_query_str,
params,
)
res = res[mutation_name]
task_id = res["taskId"]
user: User = self.client.get_user()
tasks: List[Task] = list(
user.created_tasks(where=Entity.Task.uid == task_id))
# Cache the user in a private variable, as the relationship can't be
# resolved due to server-side limitations (see Task.created_by for
# more info).
if len(tasks) != 1:
raise ResourceNotFoundError(Entity.Task, task_id)
task: Task = tasks[0]
task._user = user
return task


class ModelRunDataRow(DbObject):
label_id = Field.String("label_id")
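A usage sketch of the new method, mirroring the integration test further down (the model_run variable and the newline-delimited-JSON download step are assumptions drawn from that test, not documented API guarantees):

import json
import requests

# Illustrative only: kick off an export for an existing ModelRun instance.
task = model_run.export_v2("my_export_task",
                           params={"media_attributes": True})
task.wait_till_done()
assert task.status == "COMPLETE"

# The result is newline-delimited JSON hosted at task.result_url.
response = requests.get(task.result_url)
response.raise_for_status()
rows = [json.loads(line) for line in response.text.splitlines()]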
3 changes: 3 additions & 0 deletions labelbox/schema/task.py
@@ -81,6 +81,9 @@ def wait_till_done(self, timeout_seconds=300) -> None:
def errors(self) -> Optional[Dict[str, Any]]:
""" Fetch the error associated with an import task.
"""
# TODO: We should handle error messages for export v2 tasks in the future.
if self.name != 'JSON Import':
return None
if self.status == "FAILED":
result = self._fetch_remote_json()
return result["error"]
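Because errors() now short-circuits for anything that is not a 'JSON Import' task, a caller waiting on an export v2 task would rely on status for failure detection; a minimal sketch under that assumption (the task variable is assumed to come from export_v2):

# Illustrative only: errors() returns None for export v2 tasks,
# so check Task.status after waiting.
task.wait_till_done()
if task.status == "FAILED":
    raise RuntimeError(f"Export task {task.uid} failed")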
49 changes: 49 additions & 0 deletions tests/integration/annotation_import/test_model_run.py
@@ -1,8 +1,11 @@
import json
import time
import os
import pytest

from collections import Counter

import requests
from labelbox import DataSplit, ModelRun


@@ -114,6 +117,52 @@ def test_model_run_export_labels(model_run_with_model_run_data_rows):
assert len(labels) == 3


def test_model_run_export_v2(model_run_with_model_run_data_rows,
configured_project):
task_name = "test_task"

media_attributes = True
params = {"media_attributes": media_attributes}
task = model_run_with_model_run_data_rows.export_v2(task_name,
params=params)
assert task.name == task_name
task.wait_till_done()
assert task.status == "COMPLETE"
Contributor

It would be good to check at least that the high-level keys exist in the payload and that the number of items returned is as expected.

Contributor Author

Shouldn't that be in the scope of integration tests of the API, not the SDK?

Contributor

The SDK is the most useful end-to-end integration test, so it is nice to have it here to guarantee that things are working (it also blocks deployments). But technically it is weird that it serves that purpose.


def download_result(result_url):
response = requests.get(result_url)
response.raise_for_status()
data = [json.loads(line) for line in response.text.splitlines()]
return data

task_results = download_result(task.result_url)

label_ids = [label.uid for label in configured_project.labels()]
label_ids_set = set(label_ids)

assert len(task_results) == len(label_ids)
for task_result in task_results:
assert len(task_result['errors']) == 0
# Check export param handling
if media_attributes:
assert 'media_attributes' in task_result and task_result[
'media_attributes'] is not None
else:
assert 'media_attributes' not in task_result or task_result[
'media_attributes'] is None
model_run = task_result['models'][
model_run_with_model_run_data_rows.model_id]['model_runs'][
model_run_with_model_run_data_rows.uid]
task_label_ids_set = set(
map(lambda label: label['id'], model_run['labels']))
task_prediction_ids_set = set(
map(lambda prediction: prediction['id'], model_run['predictions']))
for label_id in task_label_ids_set:
assert label_id in label_ids_set
for prediction_id in task_prediction_ids_set:
assert prediction_id in label_ids_set


@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem",
reason="does not work for onprem")
def test_model_run_status(model_run_with_model_run_data_rows):