3 changes: 1 addition & 2 deletions labelbox/client.py
@@ -1559,8 +1559,7 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice:
Returns:
CatalogSlice
"""
query_str = """
query getSavedQueryPyApi($id: ID!) {
query_str = """query getSavedQueryPyApi($id: ID!) {
Contributor Author commented:
We have a schema checker that expects the query to be on the first line.
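
For context, a first-line check of that kind might look like this sketch (hypothetical; the actual schema checker is not part of this diff):

# Hypothetical sketch of a first-line check, not the actual schema checker.
def operation_starts_on_first_line(query_str: str) -> bool:
    first_line = query_str.splitlines()[0].strip()
    return first_line.startswith(("query", "mutation", "subscription"))
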

getSavedQuery(id: $id) {
id
name
110 changes: 108 additions & 2 deletions labelbox/schema/data_row.py
@@ -1,14 +1,18 @@
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional
import json
from labelbox.exceptions import ResourceNotFoundError

from labelbox.orm import query
from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore
from labelbox.schema.export_params import CatalogExportParams
from labelbox.schema.task import Task
from labelbox.schema.user import User # type: ignore

if TYPE_CHECKING:
-from labelbox import AssetAttachment
+from labelbox import AssetAttachment, Client

logger = logging.getLogger(__name__)

@@ -150,3 +154,105 @@ def create_attachment(self,
})
return Entity.AssetAttachment(self.client,
res["createDataRowAttachment"])

@staticmethod
def export_v2(client: 'Client',
data_rows: List['DataRow'],
task_name: Optional[str] = None,
params: Optional[CatalogExportParams] = None) -> Task:
"""
Creates an export task for the given list of data rows and params, and returns the task.

>>> dataset = client.get_dataset(DATASET_ID)
>>> task = DataRow.export_v2(
>>>     client=client,
>>>     data_rows=[data_row for data_row in dataset.data_rows.list()],
>>>     params={
>>>         "performance_details": False,
>>>         "label_details": True
>>>     })
>>> task.wait_till_done()
>>> task.result
>>> task.wait_till_done()
>>> task.result
"""

_params = params or CatalogExportParams({
"attachments": False,
"metadata_fields": False,
"data_row_details": False,
"project_details": False,
"performance_details": False,
"label_details": False,
"media_type_override": None,
"model_runs_ids": None,
"projects_ids": None,
})

mutation_name = "exportDataRowsInCatalog"
create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
%s(input: $input) {taskId} }
""" % (mutation_name)

data_rows_ids = [data_row.uid for data_row in data_rows]
search_query: List[Dict[str, Collection[str]]] = []
search_query.append({
"ids": data_rows_ids,
"operator": "is",
"type": "data_row_id"
})

media_type_override = _params.get('media_type_override', None)

if task_name is None:
task_name = f"Export v2: data rows (%s)" % len(data_rows_ids)
query_params = {
"input": {
"taskName": task_name,
"filters": {
"searchQuery": {
"scope": None,
"query": search_query
}
},
"params": {
"mediaTypeOverride":
media_type_override.value
if media_type_override is not None else None,
"includeAttachments":
_params.get('attachments', False),
"includeMetadata":
_params.get('metadata_fields', False),
"includeDataRowDetails":
_params.get('data_row_details', False),
"includeProjectDetails":
_params.get('project_details', False),
"includePerformanceDetails":
_params.get('performance_details', False),
"includeLabelDetails":
_params.get('label_details', False)
},
}
}

res = client.execute(
create_task_query_str,
query_params,
)
res = res[mutation_name]
task_id = res["taskId"]
user: User = client.get_user()
tasks: List[Task] = list(
user.created_tasks(where=Entity.Task.uid == task_id))
# Cache user in a private variable as the relationship can't be
# resolved due to server-side limitations (see Task.created_by)
# for more info.
if len(tasks) != 1:
raise ResourceNotFoundError(Entity.Task, task_id)
task: Task = tasks[0]
task._user = user
return task
195 changes: 191 additions & 4 deletions labelbox/schema/dataset.py
@@ -1,4 +1,4 @@
-from typing import Generator, List, Union, Any, TYPE_CHECKING
+from typing import Collection, Dict, Generator, List, Optional, Union, Any, TYPE_CHECKING
import os
import json
import logging
@@ -17,9 +17,12 @@
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.orm import query
from labelbox.exceptions import MalformedQueryException

if TYPE_CHECKING:
-from labelbox import Task, User, DataRow
+from labelbox.schema.data_row import DataRow
+from labelbox.schema.export_filters import DatasetExportFilters, SharedExportFilters
+from labelbox.schema.export_params import CatalogExportParams
+from labelbox.schema.project import _validate_datetime
+from labelbox.schema.task import Task
+from labelbox.schema.user import User

logger = logging.getLogger(__name__)

@@ -534,3 +537,187 @@ def export_data_rows(self,
logger.debug("Dataset '%s' data row export, waiting for server...",
self.uid)
time.sleep(sleep_time)

def export_v2(self,
task_name: Optional[str] = None,
filters: Optional[DatasetExportFilters] = None,
params: Optional[CatalogExportParams] = None) -> Task:
"""
Creates a dataset export task with the given params and returns the task.

>>> dataset = client.get_dataset(DATASET_ID)
>>> task = dataset.export_v2(
>>> filters={
>>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
>>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
>>> },
>>> params={
>>> "performance_details": False,
>>> "label_details": True
>>> })
>>> task.wait_till_done()
>>> task.result
"""

_params = params or CatalogExportParams({
"attachments": False,
"metadata_fields": False,
"data_row_details": False,
"project_details": False,
"performance_details": False,
"label_details": False,
"media_type_override": None,
"model_runs_ids": None,
"projects_ids": None,
})

_filters = filters or DatasetExportFilters({
"last_activity_at": None,
"label_created_at": None
})

def _get_timezone() -> str:
timezone_query_str = """query CurrentUserPyApi { user { timezone } }"""
tz_res = self.client.execute(timezone_query_str)
return tz_res["user"]["timezone"] or "UTC"

timezone: Optional[str] = None

mutation_name = "exportDataRowsInCatalog"
create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
%s(input: $input) {taskId} }
""" % (mutation_name)

search_query: List[Dict[str, Collection[str]]] = []
search_query.append({
"ids": [self.uid],
"operator": "is",
"type": "dataset"
})
media_type_override = _params.get('media_type_override', None)

if task_name is None:
task_name = f"Export v2: dataset - {self.name}"
query_params = {
"input": {
"taskName": task_name,
"filters": {
"searchQuery": {
"scope": None,
"query": search_query
}
},
"params": {
"mediaTypeOverride":
media_type_override.value
if media_type_override is not None else None,
"includeAttachments":
_params.get('attachments', False),
"includeMetadata":
_params.get('metadata_fields', False),
"includeDataRowDetails":
_params.get('data_row_details', False),
"includeProjectDetails":
_params.get('project_details', False),
"includePerformanceDetails":
_params.get('performance_details', False),
"includeLabelDetails":
_params.get('label_details', False)
},
}
}

if "last_activity_at" in _filters and _filters[
'last_activity_at'] is not None:
if timezone is None:
timezone = _get_timezone()
values = _filters['last_activity_at']
start, end = values
if (start is not None and end is not None):
[_validate_datetime(date) for date in values]
search_query.append({
"type": "data_row_last_activity_at",
"value": {
"operator": "BETWEEN",
"timezone": timezone,
"value": {
"min": start,
"max": end
}
}
})
elif (start is not None):
_validate_datetime(start)
search_query.append({
"type": "data_row_last_activity_at",
"value": {
"operator": "GREATER_THAN_OR_EQUAL",
"timezone": timezone,
"value": start
}
})
elif (end is not None):
_validate_datetime(end)
search_query.append({
"type": "data_row_last_activity_at",
"value": {
"operator": "LESS_THAN_OR_EQUAL",
"timezone": timezone,
"value": end
}
})

if "label_created_at" in _filters and _filters[
"label_created_at"] is not None:
if timezone is None:
timezone = _get_timezone()
values = _filters['label_created_at']
start, end = values
if (start is not None and end is not None):
[_validate_datetime(date) for date in values]
search_query.append({
"type": "labeled_at",
"value": {
"operator": "BETWEEN",
"value": {
"min": start,
"max": end
}
}
})
elif (start is not None):
_validate_datetime(start)
search_query.append({
"type": "labeled_at",
"value": {
"operator": "GREATER_THAN_OR_EQUAL",
"value": start
}
})
elif (end is not None):
_validate_datetime(end)
search_query.append({
"type": "labeled_at",
"value": {
"operator": "LESS_THAN_OR_EQUAL",
"value": end
}
})

res = self.client.execute(
create_task_query_str,
query_params,
)
res = res[mutation_name]
task_id = res["taskId"]
user: User = self.client.get_user()
tasks: List[Task] = list(
user.created_tasks(where=Entity.Task.uid == task_id))
Contributor commented on lines +711 to +715:
A minor suggestion: these 3 export_v2 methods could be moved to a separate dedicated service like ExportService. This service could have 3 methods: export_dataset, export_data_row and export_slice. These methods could then be called in the corresponding classes (Dataset, DataRow and Slice).

This way you would be able to reuse the code parts that are currently duplicated in Dataset.export_v2, DataRow.export_v2 and Slice.export_v2: fetching the task object, or building and calling exportDataRowsInCatalog (though that one is duplicated only in the DataRow and Dataset classes).

Besides reducing code duplication, it would increase the coherence of the code in ExportService. Also, the Dataset class will grow to 723 lines with this change; putting most of the logic into ExportService would avoid inflating the modules' sizes.
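
A minimal sketch of what such a service might look like (hypothetical; all names and signatures are illustrative, not part of this PR):

# Hypothetical ExportService sketch; names and signatures are illustrative only.
class ExportService:

    def __init__(self, client):
        self.client = client

    def export_dataset(self, dataset, task_name=None, filters=None, params=None):
        # Dataset scope: select by the dataset's uid.
        search_query = [{"ids": [dataset.uid], "operator": "is", "type": "dataset"}]
        task_name = task_name or f"Export v2: dataset - {dataset.name}"
        return self._run_export(task_name, search_query, params)

    def export_data_rows(self, data_rows, task_name=None, params=None):
        # Data-row scope: select by explicit data row ids.
        ids = [dr.uid for dr in data_rows]
        search_query = [{"ids": ids, "operator": "is", "type": "data_row_id"}]
        task_name = task_name or f"Export v2: data rows ({len(ids)})"
        return self._run_export(task_name, search_query, params)

    def _run_export(self, task_name, search_query, params):
        # Shared logic that is currently duplicated: build and execute the
        # exportDataRowsInCatalog mutation, then resolve the Task via the
        # current user's created tasks.
        ...
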

Contributor Author (@mnoszczak) replied on Mar 20, 2023:
That was my thought as well — will refactor this as part of the SDK v4 initiative, which aims to use REST endpoints for the SDK instead (we'll need backend changes to facilitate this).

# Cache user in a private variable as the relationship can't be
# resolved due to server-side limitations (see Task.created_by)
# for more info.
if len(tasks) != 1:
raise ResourceNotFoundError(Entity.Task, task_id)
task: Task = tasks[0]
task._user = user
return task
10 changes: 9 additions & 1 deletion labelbox/schema/export_filters.py
@@ -9,7 +9,7 @@
from typing import Tuple


-class ProjectExportFilters(TypedDict):
+class SharedExportFilters(TypedDict):
label_created_at: Optional[Tuple[str, str]]
""" Date range for labels created at
Formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss"
@@ -26,3 +26,11 @@ class ProjectExportFilters(TypedDict):
>>> [None, "2050-01-01 00:00:00"]
>>> ["2000-01-01 00:00:00", None]
"""


class ProjectExportFilters(SharedExportFilters):
pass


class DatasetExportFilters(SharedExportFilters):
pass
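
For illustration, a DatasetExportFilters value could be built like this (a sketch following the date formats documented above; passing it to dataset.export_v2 is shown in that method's docstring):

from labelbox.schema.export_filters import DatasetExportFilters

# Sketch: bound last activity on both sides, leave label creation
# open-ended on the upper side (None keeps that side of the range open).
filters = DatasetExportFilters(
    last_activity_at=("2000-01-01 00:00:00", "2050-01-01 00:00:00"),
    label_created_at=("2023-01-01", None),
)
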
11 changes: 10 additions & 1 deletion labelbox/schema/export_params.py
@@ -1,6 +1,6 @@
import sys

-from typing import Optional
+from typing import Optional, List

from labelbox.schema.media_type import MediaType
if sys.version_info >= (3, 8):
@@ -22,6 +22,15 @@ class ProjectExportParams(DataRowParams):
performance_details: Optional[bool]


class CatalogExportParams(DataRowParams):
project_details: Optional[bool]
label_details: Optional[bool]
performance_details: Optional[bool]
model_runs_ids: Optional[List[str]]
projects_ids: Optional[List[str]]


class ModelRunExportParams(DataRowParams):
# TODO: Add model run fields
pass
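
For illustration, a CatalogExportParams value might be built like this (a sketch; the key set mirrors the defaults used by the new export_v2 methods, with the remaining keys coming from DataRowParams):

from labelbox.schema.export_params import CatalogExportParams

# Sketch: request project and label details, leave everything else off.
params = CatalogExportParams({
    "attachments": False,
    "metadata_fields": False,
    "data_row_details": False,
    "project_details": True,
    "performance_details": False,
    "label_details": True,
    "media_type_override": None,
    "model_runs_ids": None,
    "projects_ids": None,
})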