Merged

25 commits
8951086
add dataset.upsert_data_rows method
attila-papai Mar 7, 2024
4d59492
add DataRowSpec class, update return type of upsertDataRows
attila-papai Mar 12, 2024
90e7a4c
fix item validation for upsert
attila-papai Mar 14, 2024
456bdba
improve dataset.upsert_data_rows method and add more tests
attila-papai Mar 19, 2024
5a9abe3
Merge branch 'develop' into attila/PLT-344-upsert-data-rows
attila-papai Mar 19, 2024
ccaf821
improve dataset.upsert_data_rows method and add more tests
attila-papai Mar 21, 2024
1636266
adjust code to backend changes
attila-papai Mar 21, 2024
25753e4
add upsert chunk size constant
attila-papai Mar 21, 2024
74fba0f
add test for multiple chunks
attila-papai Mar 21, 2024
b70e650
mypy fix
attila-papai Mar 21, 2024
bb071b8
mypy fix
attila-papai Mar 21, 2024
7e94b45
exclude None from json
attila-papai Mar 22, 2024
8c88fe8
finalizing improvements
attila-papai Mar 22, 2024
28a4435
add media_type to DataRowSpec with a test
attila-papai Mar 22, 2024
ea7cd99
add comment
attila-papai Mar 22, 2024
8e74651
add test for errors checking
attila-papai Mar 25, 2024
adb494b
mangle chunk size constant
attila-papai Mar 25, 2024
7ec2789
upload chunks in parallel
attila-papai Mar 25, 2024
245ccb7
Merge branch 'develop' into attila/PLT-344-upsert-data-rows
attila-papai Mar 25, 2024
0eaebc9
fix mypy
attila-papai Mar 25, 2024
415f3ad
Merge branch 'attila/PLT-344-upsert-data-rows' of github.com:Labelbox…
attila-papai Mar 25, 2024
a3acbaa
remove pydantic models
attila-papai Mar 26, 2024
9624432
use unique global keys in tests
attila-papai Mar 26, 2024
410463d
improvements 2
attila-papai Mar 27, 2024
da4f989
Merge branch 'develop' into attila/PLT-344-upsert-data-rows
manuaero Mar 29, 2024
42 changes: 21 additions & 21 deletions labelbox/schema/asset_attachment.py
@@ -6,6 +6,26 @@
from labelbox.orm.model import Field


class AttachmentType(str, Enum):

@classmethod
def __missing__(cls, value: object):
if str(value) == "TEXT":
warnings.warn(
"The TEXT attachment type is deprecated. Use RAW_TEXT instead.")
return cls.RAW_TEXT
return value

VIDEO = "VIDEO"
IMAGE = "IMAGE"
IMAGE_OVERLAY = "IMAGE_OVERLAY"
HTML = "HTML"
RAW_TEXT = "RAW_TEXT"
TEXT_URL = "TEXT_URL"
PDF_URL = "PDF_URL"
CAMERA_IMAGE = "CAMERA_IMAGE" # Used by experimental point-cloud editor


class AssetAttachment(DbObject):
"""Asset attachment provides extra context about an asset while labeling.

@@ -15,26 +35,6 @@ class AssetAttachment(DbObject):
attachment_name (str): The name of the attachment
"""

class AttachmentType(Enum):

@classmethod
def __missing__(cls, value: object):
if str(value) == "TEXT":
warnings.warn(
"The TEXT attachment type is deprecated. Use RAW_TEXT instead."
)
return cls.RAW_TEXT
return value

VIDEO = "VIDEO"
IMAGE = "IMAGE"
IMAGE_OVERLAY = "IMAGE_OVERLAY"
HTML = "HTML"
RAW_TEXT = "RAW_TEXT"
TEXT_URL = "TEXT_URL"
PDF_URL = "PDF_URL"
CAMERA_IMAGE = "CAMERA_IMAGE" # Used by experimental point-cloud editor

for topic in AttachmentType:
vars()[topic.name] = topic.value

@@ -61,7 +61,7 @@ def validate_attachment_value(cls, attachment_value: str) -> None:

@classmethod
def validate_attachment_type(cls, attachment_type: str) -> None:
valid_types = set(cls.AttachmentType.__members__)
valid_types = set(AttachmentType.__members__)
if attachment_type not in valid_types:
raise ValueError(
f"attachment_type must be one of {valid_types}. Found {attachment_type}"
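The diff above promotes AttachmentType from a nested class on AssetAttachment to a module-level enum, so validation now checks AttachmentType.__members__ directly. A minimal caller-side sketch (not part of this PR; the import path is the one shown in the diff) of how that membership check behaves:

# Sketch only: mirrors validate_attachment_type after the move.
from labelbox.schema.asset_attachment import AttachmentType


def is_valid_attachment_type(attachment_type: str) -> bool:
    # Membership is checked against enum member names, not the deprecated "TEXT" alias.
    return attachment_type in set(AttachmentType.__members__)


assert is_valid_attachment_type("RAW_TEXT")
assert not is_valid_attachment_type("TEXT")  # only the enum lookup path maps TEXT -> RAW_TEXT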
17 changes: 14 additions & 3 deletions labelbox/schema/data_row.py
@@ -1,10 +1,12 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Union
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union, Any
import json

from labelbox.orm import query
from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable, experimental
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.schema.asset_attachment import AttachmentType
from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore
from labelbox.schema.export_filters import DatarowExportFilters, build_filters, validate_at_least_one_of_data_row_ids_or_global_keys
from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
@@ -17,6 +19,15 @@
logger = logging.getLogger(__name__)


class KeyType(str, Enum):
ID = 'ID'
"""An existing CUID"""
GKEY = 'GKEY'
"""A Global key, could be existing or non-existing"""
AUTO = 'AUTO'
"""The key will be auto-generated. Only usable for creates"""


class DataRow(DbObject, Updateable, BulkDeletable):
""" Internal Labelbox representation of a single piece of data (e.g. image, video, text).

@@ -62,7 +73,7 @@ class DataRow(DbObject, Updateable, BulkDeletable):
attachments = Relationship.ToMany("AssetAttachment", False, "attachments")

supported_meta_types = supported_attachment_types = set(
Entity.AssetAttachment.AttachmentType.__members__)
AttachmentType.__members__)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -131,7 +142,7 @@ def create_attachment(self,

Args:
attachment_type (str): Asset attachment type, must be one of:
VIDEO, IMAGE, TEXT, IMAGE_OVERLAY (AssetAttachment.AttachmentType)
VIDEO, IMAGE, TEXT, IMAGE_OVERLAY (AttachmentType)
attachment_value (str): Asset attachment value.
attachment_name (str): (Optional) Asset attachment name.
Returns:
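The new KeyType enum above names the three ways an upsert item can be keyed. A small sketch (assumed usage, not taken from this diff; the id value is a placeholder) of how those values line up with the key dicts that _convert_items_to_upsert_format builds from literal strings:

# Sketch only: KeyType is imported from the module added to in this diff.
from labelbox.schema.data_row import KeyType

key_by_id = {"type": KeyType.ID.value, "value": "<existing-data-row-cuid>"}  # update by data row id
key_by_gkey = {"type": KeyType.GKEY.value, "value": "my-global-key"}         # update or create by global key
key_auto = {"type": KeyType.AUTO.value, "value": ""}                         # create with a server-generated key

print(key_by_id, key_by_gkey, key_auto, sep="\n")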
149 changes: 135 additions & 14 deletions labelbox/schema/dataset.py
@@ -1,4 +1,5 @@
from typing import Dict, Generator, List, Optional, Union, Any
from datetime import datetime
from typing import Dict, Generator, List, Optional, Any, Final
import os
import json
import logging
@@ -14,18 +15,19 @@
from io import StringIO
import requests

from labelbox import pagination
from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
from labelbox.orm.comparison import Comparison
from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.orm import query
from labelbox.exceptions import MalformedQueryException
from labelbox.pagination import PaginatedCollection
from labelbox.pydantic_compat import BaseModel
from labelbox.schema.data_row import DataRow
from labelbox.schema.export_filters import DatasetExportFilters, build_filters
from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
from labelbox.schema.export_task import ExportTask
from labelbox.schema.identifiable import UniqueId, GlobalKey
from labelbox.schema.task import Task
from labelbox.schema.user import User

@@ -34,6 +36,11 @@
MAX_DATAROW_PER_API_OPERATION = 150_000


class DataRowUpsertItem(BaseModel):
id: dict
payload: dict


class Dataset(DbObject, Updateable, Deletable):
""" A Dataset is a collection of DataRows.

@@ -47,6 +54,8 @@ class Dataset(DbObject, Updateable, Deletable):
created_by (Relationship): `ToOne` relationship to User
organization (Relationship): `ToOne` relationship to Organization
"""
__upsert_chunk_size: Final = 10_000
Contributor
how was this number chosen?

Contributor
nit: In Python, constants are typically written in all capital letters.

name = Field.String("name")
description = Field.String("description")
updated_at = Field.DateTime("updated_at")
@@ -64,16 +73,16 @@ def data_rows(
from_cursor: Optional[str] = None,
where: Optional[Comparison] = None,
) -> PaginatedCollection:
"""
"""
Custom method to paginate data_rows via cursor.

Args:
from_cursor (str): Cursor (data row id) to start from, if none, will start from the beginning
where (dict(str,str)): Filter to apply to data rows. Where value is a data row column name and key is the value to filter on.
where (dict(str,str)): Filter to apply to data rows. Where value is a data row column name and key is the value to filter on.
example: {'external_id': 'my_external_id'} to get a data row with external_id = 'my_external_id'


NOTE:
NOTE:
Order of retrieval is newest data row first.
Deleted data rows are not retrieved.
Failed data rows are not retrieved.
@@ -293,7 +302,10 @@ def create_data_rows(self, items) -> "Task":
task._user = user
return task

def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
def _create_descriptor_file(self,
items,
max_attachments_per_data_row=None,
is_upsert=False):
Contributor
I understand you wanted to keep it DRY, but the code is arguably more complex as a result.

Contributor Author
I will address this in the follow-up work.

"""
This function is shared by both `Dataset.create_data_rows` and `Dataset.create_data_rows_sync`
to prepare the input file. The user defined input is validated, processed, and json stringified.
@@ -346,6 +358,9 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
AssetAttachment = Entity.AssetAttachment

def upload_if_necessary(item):
if is_upsert and 'row_data' not in item:
# When upserting, row_data is not required
return item
row_data = item['row_data']
if isinstance(row_data, str) and os.path.exists(row_data):
item_url = self.client.upload_file(row_data)
@@ -425,17 +440,17 @@ def format_row(item):
return item

def validate_keys(item):
if 'row_data' not in item:
if not is_upsert and 'row_data' not in item:
raise InvalidQueryError(
"`row_data` missing when creating DataRow.")

if isinstance(item.get('row_data'),
str) and item.get('row_data').startswith("s3:/"):
raise InvalidQueryError(
"row_data: s3 assets must start with 'https'.")
invalid_keys = set(item) - {
*{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
}
allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}
invalid_keys = set(item) - {f.name for f in DataRow.fields()
} - allowed_extra_fields
if invalid_keys:
raise InvalidAttributeError(DataRow, invalid_keys)
return item
@@ -460,7 +475,12 @@ def formatLegacyConversationalData(item):
item["row_data"] = one_conversation
return item

def convert_item(item):
def convert_item(data_row_item):
if isinstance(data_row_item, DataRowUpsertItem):
item = data_row_item.payload
else:
item = data_row_item

if "tileLayerUrl" in item:
validate_attachments(item)
return item
Expand All @@ -478,7 +498,11 @@ def convert_item(item):
parse_metadata_fields(item)
# Upload any local file paths
item = upload_if_necessary(item)
return item

if isinstance(data_row_item, DataRowUpsertItem):
return {'id': data_row_item.id, 'payload': item}
else:
return item

if not isinstance(items, Iterable):
raise ValueError(
@@ -638,13 +662,13 @@ def export_v2(
) -> Task:
"""
Creates a dataset export task with the given params and returns the task.

>>> dataset = client.get_dataset(DATASET_ID)
>>> task = dataset.export_v2(
>>> filters={
>>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
>>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
>>> "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
>>> "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
>>> },
>>> params={
>>> "performance_details": False,
@@ -749,3 +773,100 @@ def _export(
res = res[mutation_name]
task_id = res["taskId"]
return Task.get_task(self.client, task_id)

def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task":
"""
Upserts data rows in this dataset. When "key" is provided, and it references an existing data row,
an update will be performed. When "key" is not provided a new data row will be created.

>>> task = dataset.upsert_data_rows([
>>> # create new data row
>>> {
>>> "row_data": "http://my_site.com/photos/img_01.jpg",
>>> "global_key": "global_key1",
>>> "external_id": "ex_id1",
>>> "attachments": [
>>> {"type": AttachmentType.RAW_TEXT, "name": "att1", "value": "test1"}
>>> ],
>>> "metadata": [
>>> {"name": "tag", "value": "tag value"},
>>> ]
>>> },
>>> # update global key of data row by existing global key
>>> {
>>> "key": GlobalKey("global_key1"),
>>> "global_key": "global_key1_updated"
>>> },
>>> # update data row by ID
>>> {
>>> "key": UniqueId(dr.uid),
>>> "external_id": "ex_id1_updated"
>>> },
>>> ])
>>> task.wait_till_done()
"""
if len(items) > MAX_DATAROW_PER_API_OPERATION:
raise MalformedQueryException(
f"Cannot upsert more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
)

specs = self._convert_items_to_upsert_format(items)
chunks = [
specs[i:i + self.__upsert_chunk_size]
for i in range(0, len(specs), self.__upsert_chunk_size)
]

def _upload_chunk(_chunk):
return self._create_descriptor_file(_chunk, is_upsert=True)

with ThreadPoolExecutor(file_upload_thread_count) as executor:
futures = [
executor.submit(_upload_chunk, chunk) for chunk in chunks
]
chunk_uris = [future.result() for future in as_completed(futures)]

manifest = {
"source": "SDK",
"item_count": len(specs),
"chunk_uris": chunk_uris
}
data = json.dumps(manifest).encode("utf-8")
manifest_uri = self.client.upload_data(data,
content_type="application/json",
filename="manifest.json")

query_str = """
mutation UpsertDataRowsPyApi($manifestUri: String!) {
upsertDataRows(data: { manifestUri: $manifestUri }) {
id createdAt updatedAt name status completionPercentage result errors type metadata
}
}
"""

res = self.client.execute(query_str, {"manifestUri": manifest_uri})
res = res["upsertDataRows"]
task = Task(self.client, res)
task._user = self.client.get_user()
return task

def _convert_items_to_upsert_format(self, _items):
_upsert_items: List[DataRowUpsertItem] = []
for item in _items:
# enforce current dataset's id for all specs
item['dataset_id'] = self.uid
key = item.pop('key', None)
if not key:
key = {'type': 'AUTO', 'value': ''}
elif isinstance(key, UniqueId):
key = {'type': 'ID', 'value': key.key}
elif isinstance(key, GlobalKey):
key = {'type': 'GKEY', 'value': key.key}
else:
raise ValueError(
f"Key must be an instance of UniqueId or GlobalKey, got: {type(item['key']).__name__}"
)
item = {
k: v for k, v in item.items() if v is not None
} # remove None values
_upsert_items.append(DataRowUpsertItem(payload=item, id=key))
return _upsert_items
6 changes: 3 additions & 3 deletions labelbox/schema/task.py
@@ -66,7 +66,7 @@ def wait_till_done(self,

Args:
timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes.
check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds.
check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds.
"""
if check_frequency < 2.0:
raise ValueError(
@@ -90,7 +90,7 @@ def errors(self) -> Optional[Dict[str, Any]]:
def errors(self) -> Optional[Dict[str, Any]]:
""" Fetch the error associated with an import task.
"""
if self.name == 'JSON Import':
if self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows':
if self.status == "FAILED":
result = self._fetch_remote_json()
return result["error"]
@@ -168,7 +168,7 @@ def download_result(remote_json_field: Optional[str], format: str):
"Expected the result format to be either `ndjson` or `json`."
)

if self.name == 'JSON Import':
if self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows':
format = 'json'
elif self.type == 'export-data-rows':
format = 'ndjson'
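The task.py change above routes tasks of type 'adv-upsert-data-rows' through the same error/result handling as 'JSON Import' tasks. A hedged end-to-end sketch of consuming the task returned by upsert_data_rows; the API key, dataset id, and global key are placeholders:

# Sketch only: assumes a configured client and an existing dataset and global key.
from labelbox import Client
from labelbox.schema.identifiable import GlobalKey

client = Client(api_key="<YOUR_API_KEY>")
dataset = client.get_dataset("<DATASET_ID>")

task = dataset.upsert_data_rows([
    {"key": GlobalKey("my-global-key"), "external_id": "ex-id-updated"},
])
task.wait_till_done()

# For 'adv-upsert-data-rows' tasks, errors and results are read from the remote JSON payload.
print(task.status)
print(task.errors if task.status == "FAILED" else task.result)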