[AL-4398] Add Model Run exports #840
Changes from all commits: 798f518, 18277f1, 2d89326, 3fa9a9e, b26f235, d5b588a, 304d8e9, 3d80556, 2f65e22, 4fd153e, 8210978, 598db8d, 2347a2b, 3b065b8, dc7d985, 32c30f2, a708c31, c9d8265
New file added by this PR (`@@ -0,0 +1,25 @@`):

```python
import sys

from typing import Optional
if sys.version_info >= (3, 8):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict


class DataRowParams(TypedDict):
    data_row_details: Optional[bool]
    media_attributes: Optional[bool]
    metadata_fields: Optional[bool]
    attachments: Optional[bool]


class ProjectExportParams(DataRowParams):
    project_details: Optional[bool]
    label_details: Optional[bool]
    performance_details: Optional[bool]


class ModelRunExportParams(DataRowParams):
    # TODO: Add model run fields
    pass
```
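These TypedDicts describe the optional sections an export can include; `ModelRunExportParams` currently inherits only the data-row flags. A minimal sketch of building such a params dict, assuming the module path in the import (the new file's name is not visible in this diff) and that omitted keys fall back to server-side defaults:

```python
# Assumed module path; the diff does not show the new file's name.
from labelbox.schema.export_params import ModelRunExportParams

# All fields are Optional[bool]; a TypedDict is a plain dict at runtime.
params: ModelRunExportParams = {
    "data_row_details": True,
    "media_attributes": True,
    "metadata_fields": False,
    "attachments": False,
}

# Passed to the export call added in this PR, e.g.:
# task = model_run.export_v2("my-export", params=params)
```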
Changes to the model run integration tests.

Imports (`@@ -1,8 +1,11 @@`):

```python
import json
import time
import os
import pytest

from collections import Counter

import requests
from labelbox import DataSplit, ModelRun
```

New test (`@@ -114,6 +117,52 @@ def test_model_run_export_labels(model_run_with_model_run_data_rows):`):

```python
    assert len(labels) == 3


def test_model_run_export_v2(model_run_with_model_run_data_rows,
                             configured_project):
    task_name = "test_task"

    media_attributes = True
    params = {"media_attributes": media_attributes}
    task = model_run_with_model_run_data_rows.export_v2(task_name,
                                                        params=params)
    assert task.name == task_name
    task.wait_till_done()
    assert task.status == "COMPLETE"
```
Inline review thread on this test:

**Contributor:** It would be good to check at least that the high-level keys exist in the payload and that the number of items returned is what we expect.

**Contributor (Author):** Shouldn't that be in the scope of the API's integration tests, not the SDK's?

**Contributor:** The SDK is the most useful end-to-end integration test, so it is nice to have it here to guarantee that things are working (it also blocks deployments). But technically it is weird that it serves that purpose.
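A minimal sketch of the payload check being suggested, assuming each exported line is a dict whose top-level keys include `errors`, `media_attributes`, and `models` (inferred from the assertions further down in this test); the helper name is hypothetical:

```python
def check_export_payload(task_results, expected_count):
    # Hypothetical helper: verify the row count and that the high-level keys exist.
    assert len(task_results) == expected_count
    expected_keys = ("errors", "media_attributes", "models")
    for row in task_results:
        for key in expected_keys:
            assert key in row, f"missing top-level key: {key}"
```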
Continuation of `test_model_run_export_v2`:

```python
    def download_result(result_url):
        response = requests.get(result_url)
        response.raise_for_status()
        data = [json.loads(line) for line in response.text.splitlines()]
        return data

    task_results = download_result(task.result_url)

    label_ids = [label.uid for label in configured_project.labels()]
    label_ids_set = set(label_ids)

    assert len(task_results) == len(label_ids)
    for task_result in task_results:
        assert len(task_result['errors']) == 0
        # Check export param handling
        if media_attributes:
            assert 'media_attributes' in task_result and task_result[
                'media_attributes'] is not None
        else:
            assert 'media_attributes' not in task_result or task_result[
                'media_attributes'] is None
        model_run = task_result['models'][
            model_run_with_model_run_data_rows.model_id]['model_runs'][
                model_run_with_model_run_data_rows.uid]
        task_label_ids_set = set(
            map(lambda label: label['id'], model_run['labels']))
        task_prediction_ids_set = set(
            map(lambda prediction: prediction['id'], model_run['predictions']))
        for label_id in task_label_ids_set:
            assert label_id in label_ids_set
        for prediction_id in task_prediction_ids_set:
            assert prediction_id in label_ids_set


@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem",
                    reason="does not work for onprem")
def test_model_run_status(model_run_with_model_run_data_rows):
```
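Outside the test suite, the new export path would be exercised roughly as in the sketch below: start the export, wait for the task, then download the newline-delimited JSON result. This is only a sketch; the client setup and model run ID are placeholders, while the `export_v2`, `wait_till_done`, `status`, and `result_url` usage mirrors the test above:

```python
import json

import requests
from labelbox import Client

client = Client(api_key="<YOUR_API_KEY>")           # placeholder credentials
model_run = client.get_model_run("<MODEL_RUN_ID>")  # placeholder model run ID

# Kick off the export with the same kind of params dict used in the test.
task = model_run.export_v2("my-export", params={"media_attributes": True})
task.wait_till_done()
assert task.status == "COMPLETE"

# The result URL points at NDJSON: one JSON object per exported data row.
response = requests.get(task.result_url)
response.raise_for_status()
rows = [json.loads(line) for line in response.text.splitlines()]
print(f"exported {len(rows)} rows")
```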