[AL-5278] [AL-5279] [AL-5260] Add catalog slice, data row, dataset support #1010
Changes from all commits: 1ea2ad3, 784a121, 4f33e96, cdea274, 5a372fc, 8437b32, 750cf3d
@@ -1,4 +1,4 @@
-from typing import Generator, List, Union, Any, TYPE_CHECKING
+from typing import Collection, Dict, Generator, List, Optional, Union, Any, TYPE_CHECKING
 import os
 import json
 import logging
@@ -17,9 +17,12 @@
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException

 if TYPE_CHECKING:
-    from labelbox import Task, User, DataRow
+    from labelbox.schema.data_row import DataRow
+    from labelbox.schema.export_filters import DatasetExportFilters, SharedExportFilters
+    from labelbox.schema.export_params import CatalogExportParams
+    from labelbox.schema.project import _validate_datetime
+    from labelbox.schema.task import Task
+    from labelbox.schema.user import User

 logger = logging.getLogger(__name__)
@@ -534,3 +537,187 @@ def export_data_rows(self,
             logger.debug("Dataset '%s' data row export, waiting for server...",
                          self.uid)
             time.sleep(sleep_time)
+
+    def export_v2(self,
+                  task_name: Optional[str] = None,
+                  filters: Optional[DatasetExportFilters] = None,
+                  params: Optional[CatalogExportParams] = None) -> Task:
+        """
+        Creates a dataset export task with the given params and returns the task.
+
+        >>> dataset = client.get_dataset(DATASET_ID)
+        >>> task = dataset.export_v2(
+        >>>     filters={
+        >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
+        >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
+        >>>     },
+        >>>     params={
+        >>>         "performance_details": False,
+        >>>         "label_details": True
+        >>>     })
+        >>> task.wait_till_done()
+        >>> task.result
+        """
+
+        _params = params or CatalogExportParams({
+            "attachments": False,
+            "metadata_fields": False,
+            "data_row_details": False,
+            "project_details": False,
+            "performance_details": False,
+            "label_details": False,
+            "media_type_override": None,
+            "model_runs_ids": None,
+            "projects_ids": None,
+        })
+
+        _filters = filters or DatasetExportFilters({
+            "last_activity_at": None,
+            "label_created_at": None
+        })
+
+        def _get_timezone() -> str:
+            timezone_query_str = """query CurrentUserPyApi { user { timezone } }"""
+            tz_res = self.client.execute(timezone_query_str)
+            return tz_res["user"]["timezone"] or "UTC"
+
+        timezone: Optional[str] = None
+
+        mutation_name = "exportDataRowsInCatalog"
+        create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
+            %s(input: $input) {taskId} }
+            """ % (mutation_name)
+
+        search_query: List[Dict[str, Collection[str]]] = []
+        search_query.append({
+            "ids": [self.uid],
+            "operator": "is",
+            "type": "dataset"
+        })
+        media_type_override = _params.get('media_type_override', None)
+
+        if task_name is None:
+            task_name = f"Export v2: dataset - {self.name}"
+        query_params = {
+            "input": {
+                "taskName": task_name,
+                "filters": {
+                    "searchQuery": {
+                        "scope": None,
+                        "query": search_query
+                    }
+                },
+                "params": {
+                    "mediaTypeOverride":
+                        media_type_override.value
+                        if media_type_override is not None else None,
+                    "includeAttachments":
+                        _params.get('attachments', False),
+                    "includeMetadata":
+                        _params.get('metadata_fields', False),
+                    "includeDataRowDetails":
+                        _params.get('data_row_details', False),
+                    "includeProjectDetails":
+                        _params.get('project_details', False),
+                    "includePerformanceDetails":
+                        _params.get('performance_details', False),
+                    "includeLabelDetails":
+                        _params.get('label_details', False)
+                },
+            }
+        }
+
+        if "last_activity_at" in _filters and _filters[
+                'last_activity_at'] is not None:
+            if timezone is None:
+                timezone = _get_timezone()
+            values = _filters['last_activity_at']
+            start, end = values
+            if (start is not None and end is not None):
+                [_validate_datetime(date) for date in values]
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "BETWEEN",
+                        "timezone": timezone,
+                        "value": {
+                            "min": start,
+                            "max": end
+                        }
+                    }
+                })
+            elif (start is not None):
+                _validate_datetime(start)
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "GREATER_THAN_OR_EQUAL",
+                        "timezone": timezone,
+                        "value": start
+                    }
+                })
+            elif (end is not None):
+                _validate_datetime(end)
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "LESS_THAN_OR_EQUAL",
+                        "timezone": timezone,
+                        "value": end
+                    }
+                })
+
+        if "label_created_at" in _filters and _filters[
+                "label_created_at"] is not None:
+            if timezone is None:
+                timezone = _get_timezone()
+            values = _filters['label_created_at']
+            start, end = values
+            if (start is not None and end is not None):
+                [_validate_datetime(date) for date in values]
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "BETWEEN",
+                        "value": {
+                            "min": start,
+                            "max": end
+                        }
+                    }
+                })
+            elif (start is not None):
+                _validate_datetime(start)
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "GREATER_THAN_OR_EQUAL",
+                        "value": start
+                    }
+                })
+            elif (end is not None):
+                _validate_datetime(end)
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "LESS_THAN_OR_EQUAL",
+                        "value": end
+                    }
+                })
+
+        res = self.client.execute(
+            create_task_query_str,
+            query_params,
+        )
+        res = res[mutation_name]
+        task_id = res["taskId"]
+        user: User = self.client.get_user()
+        tasks: List[Task] = list(
+            user.created_tasks(where=Entity.Task.uid == task_id))
Comment on lines +711 to +715

Reviewer: A minor suggestion: these 3 export_v2 implementations could share a common helper. This way you would be able to reuse the code parts that are currently duplicated across them. Besides reducing code duplication, it would increase the coherence of the code.

Author: That was my thought as well; I will refactor this as part of the SDK v4 initiative, which aims to move the SDK to REST endpoints instead (we'll need backend changes to facilitate this).
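For illustration, a minimal sketch of such a shared helper; the name _get_task_by_id is hypothetical and not part of this PR, and the body reuses only calls already present in the diff:

    def _get_task_by_id(client, task_id: str) -> Task:
        """Hypothetical helper mirroring the task-fetching block that is
        duplicated across the export_v2 implementations."""
        user: User = client.get_user()
        tasks: List[Task] = list(
            user.created_tasks(where=Entity.Task.uid == task_id))
        # Cache user in a private variable as the relationship can't be
        # resolved due to server-side limitations (see Task.created_by).
        if len(tasks) != 1:
            raise ResourceNotFoundError(Entity.Task, task_id)
        task: Task = tasks[0]
        task._user = user
        return task

Each export_v2 could then end with: return _get_task_by_id(self.client, res["taskId"]).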
+        # Cache user in a private variable as the relationship can't be
+        # resolved due to server-side limitations (see Task.created_by)
+        # for more info.
+        if len(tasks) != 1:
+            raise ResourceNotFoundError(Entity.Task, task_id)
+        task: Task = tasks[0]
+        task._user = user
+        return task
Reviewer: We have a schema checker that expects the query to be on the first line.
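Presumably this refers to how the GraphQL string literals above are formatted; the checker's exact rules are not shown in this PR. A hedged illustration of the convention being asked for, with the operation starting on the first line of the literal rather than after a leading newline:

    mutation_name = "exportDataRowsInCatalog"

    # Checker-friendly: the "mutation ..." text is on line one of the literal.
    create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
        %s(input: $input) {taskId} }
        """ % mutation_name

    # Would reportedly fail such a check: the operation only starts on line two.
    bad_query_str = """
        mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
        %s(input: $input) {taskId} }
        """ % mutation_name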