diff --git a/labelbox/schema/asset_attachment.py b/labelbox/schema/asset_attachment.py
index 9ede2b5e4..fba542011 100644
--- a/labelbox/schema/asset_attachment.py
+++ b/labelbox/schema/asset_attachment.py
@@ -6,6 +6,26 @@
 from labelbox.orm.model import Field
 
 
+class AttachmentType(str, Enum):
+
+    @classmethod
+    def __missing__(cls, value: object):
+        if str(value) == "TEXT":
+            warnings.warn(
+                "The TEXT attachment type is deprecated. Use RAW_TEXT instead.")
+            return cls.RAW_TEXT
+        return value
+
+    VIDEO = "VIDEO"
+    IMAGE = "IMAGE"
+    IMAGE_OVERLAY = "IMAGE_OVERLAY"
+    HTML = "HTML"
+    RAW_TEXT = "RAW_TEXT"
+    TEXT_URL = "TEXT_URL"
+    PDF_URL = "PDF_URL"
+    CAMERA_IMAGE = "CAMERA_IMAGE"  # Used by experimental point-cloud editor
+
+
 class AssetAttachment(DbObject):
     """Asset attachment provides extra context about an asset while labeling.
 
@@ -15,26 +35,6 @@ class AssetAttachment(DbObject):
         attachment_name (str): The name of the attachment
     """
 
-    class AttachmentType(Enum):
-
-        @classmethod
-        def __missing__(cls, value: object):
-            if str(value) == "TEXT":
-                warnings.warn(
-                    "The TEXT attachment type is deprecated. Use RAW_TEXT instead."
-                )
-                return cls.RAW_TEXT
-            return value
-
-        VIDEO = "VIDEO"
-        IMAGE = "IMAGE"
-        IMAGE_OVERLAY = "IMAGE_OVERLAY"
-        HTML = "HTML"
-        RAW_TEXT = "RAW_TEXT"
-        TEXT_URL = "TEXT_URL"
-        PDF_URL = "PDF_URL"
-        CAMERA_IMAGE = "CAMERA_IMAGE"  # Used by experimental point-cloud editor
-
     for topic in AttachmentType:
         vars()[topic.name] = topic.value
 
@@ -61,7 +61,7 @@ def validate_attachment_value(cls, attachment_value: str) -> None:
 
     @classmethod
     def validate_attachment_type(cls, attachment_type: str) -> None:
-        valid_types = set(cls.AttachmentType.__members__)
+        valid_types = set(AttachmentType.__members__)
         if attachment_type not in valid_types:
             raise ValueError(
                 f"attachment_type must be one of {valid_types}. Found {attachment_type}"
diff --git a/labelbox/schema/data_row.py b/labelbox/schema/data_row.py
index 411f78879..1110998ad 100644
--- a/labelbox/schema/data_row.py
+++ b/labelbox/schema/data_row.py
@@ -1,10 +1,12 @@
 import logging
-from typing import TYPE_CHECKING, List, Optional, Union
+from enum import Enum
+from typing import TYPE_CHECKING, List, Optional, Union, Any
 import json
 
 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable, experimental
 from labelbox.orm.model import Entity, Field, Relationship
+from labelbox.schema.asset_attachment import AttachmentType
 from labelbox.schema.data_row_metadata import DataRowMetadataField  # type: ignore
 from labelbox.schema.export_filters import DatarowExportFilters, build_filters, validate_at_least_one_of_data_row_ids_or_global_keys
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
@@ -17,6 +19,15 @@
 logger = logging.getLogger(__name__)
 
 
+class KeyType(str, Enum):
+    ID = 'ID'
+    """An existing CUID"""
+    GKEY = 'GKEY'
+    """A Global key, could be existing or non-existing"""
+    AUTO = 'AUTO'
+    """The key will be auto-generated. Only usable for creates"""
+
+
 class DataRow(DbObject, Updateable, BulkDeletable):
     """ Internal Labelbox representation of a single piece of data (e.g. image, video, text).
@@ -62,7 +73,7 @@ class DataRow(DbObject, Updateable, BulkDeletable):
     attachments = Relationship.ToMany("AssetAttachment", False, "attachments")
 
     supported_meta_types = supported_attachment_types = set(
-        Entity.AssetAttachment.AttachmentType.__members__)
+        AttachmentType.__members__)
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -131,7 +142,7 @@ def create_attachment(self,
         Args:
             attachment_type (str): Asset attachment type, must be one of:
-                VIDEO, IMAGE, TEXT, IMAGE_OVERLAY (AssetAttachment.AttachmentType)
+                VIDEO, IMAGE, TEXT, IMAGE_OVERLAY (AttachmentType)
             attachment_value (str): Asset attachment value.
             attachment_name (str): (Optional) Asset attachment name.
         Returns:
diff --git a/labelbox/schema/dataset.py b/labelbox/schema/dataset.py
index 03dcaaed8..47fa658b7 100644
--- a/labelbox/schema/dataset.py
+++ b/labelbox/schema/dataset.py
@@ -1,4 +1,5 @@
-from typing import Dict, Generator, List, Optional, Union, Any
+from datetime import datetime
+from typing import Dict, Generator, List, Optional, Any, Final
 import os
 import json
 import logging
@@ -14,7 +15,6 @@
 from io import StringIO
 
 import requests
 
-from labelbox import pagination
 from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
 from labelbox.orm.comparison import Comparison
 from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental
@@ -22,10 +22,12 @@
 from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
 from labelbox.pagination import PaginatedCollection
+from labelbox.pydantic_compat import BaseModel
 from labelbox.schema.data_row import DataRow
 from labelbox.schema.export_filters import DatasetExportFilters, build_filters
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
 from labelbox.schema.export_task import ExportTask
+from labelbox.schema.identifiable import UniqueId, GlobalKey
 from labelbox.schema.task import Task
 from labelbox.schema.user import User
 
@@ -34,6 +36,11 @@
 MAX_DATAROW_PER_API_OPERATION = 150_000
 
 
+class DataRowUpsertItem(BaseModel):
+    id: dict
+    payload: dict
+
+
 class Dataset(DbObject, Updateable, Deletable):
     """ A Dataset is a collection of DataRows.
 
@@ -47,6 +54,8 @@ class Dataset(DbObject, Updateable, Deletable):
         created_by (Relationship): `ToOne` relationship to User
         organization (Relationship): `ToOne` relationship to Organization
     """
+    __upsert_chunk_size: Final = 10_000
+
     name = Field.String("name")
     description = Field.String("description")
     updated_at = Field.DateTime("updated_at")
@@ -64,16 +73,16 @@ def data_rows(
         from_cursor: Optional[str] = None,
         where: Optional[Comparison] = None,
     ) -> PaginatedCollection:
-        """
+        """
         Custom method to paginate data_rows via cursor.
 
         Args:
             from_cursor (str): Cursor (data row id) to start from, if none, will start from the beginning
-            where (dict(str,str)): Filter to apply to data rows. Where value is a data row column name and key is the value to filter on.
+            where (dict(str,str)): Filter to apply to data rows. Where value is a data row column name and key is the value to filter on.
             example: {'external_id': 'my_external_id'} to get a data row with external_id = 'my_external_id'
 
-        NOTE:
+        NOTE:
             Order of retrieval is newest data row first.
             Deleted data rows are not retrieved.
             Failed data rows are not retrieved.
@@ -293,7 +302,10 @@ def create_data_rows(self, items) -> "Task":
         task._user = user
         return task
 
-    def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
+    def _create_descriptor_file(self,
+                                items,
+                                max_attachments_per_data_row=None,
+                                is_upsert=False):
         """
         This function is shared by both `Dataset.create_data_rows` and `Dataset.create_data_rows_sync`
         to prepare the input file. The user defined input is validated, processed, and json stringified.
@@ -346,6 +358,9 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
         AssetAttachment = Entity.AssetAttachment
 
         def upload_if_necessary(item):
+            if is_upsert and 'row_data' not in item:
+                # When upserting, row_data is not required
+                return item
             row_data = item['row_data']
             if isinstance(row_data, str) and os.path.exists(row_data):
                 item_url = self.client.upload_file(row_data)
@@ -425,7 +440,7 @@ def format_row(item):
             return item
 
         def validate_keys(item):
-            if 'row_data' not in item:
+            if not is_upsert and 'row_data' not in item:
                 raise InvalidQueryError(
                     "`row_data` missing when creating DataRow.")
 
             if isinstance(item.get('row_data'),
                           str) and item.get('row_data').startswith("s3:/"):
                 raise InvalidQueryError(
                     "row_data: s3 assets must start with 'https'.")
-            invalid_keys = set(item) - {
-                *{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
-            }
+            allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}
+            invalid_keys = set(item) - {f.name for f in DataRow.fields()
+                                       } - allowed_extra_fields
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)
             return item
@@ -460,7 +475,12 @@ def formatLegacyConversationalData(item):
             item["row_data"] = one_conversation
             return item
 
-        def convert_item(item):
+        def convert_item(data_row_item):
+            if isinstance(data_row_item, DataRowUpsertItem):
+                item = data_row_item.payload
+            else:
+                item = data_row_item
+
             if "tileLayerUrl" in item:
                 validate_attachments(item)
                 return item
@@ -478,7 +498,11 @@ def convert_item(item):
             parse_metadata_fields(item)
             # Upload any local file paths
             item = upload_if_necessary(item)
-            return item
+
+            if isinstance(data_row_item, DataRowUpsertItem):
+                return {'id': data_row_item.id, 'payload': item}
+            else:
+                return item
 
         if not isinstance(items, Iterable):
             raise ValueError(
@@ -638,13 +662,13 @@ def export_v2(
     ) -> Task:
         """
         Creates a dataset export task with the given params and returns the task.
-
+
         >>> dataset = client.get_dataset(DATASET_ID)
         >>> task = dataset.export_v2(
         >>>     filters={
         >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
         >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
-        >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
+        >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
         >>>     },
         >>>     params={
         >>>         "performance_details": False,
@@ -749,3 +773,100 @@ def _export(
         res = res[mutation_name]
         task_id = res["taskId"]
         return Task.get_task(self.client, task_id)
+
+    def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task":
+        """
+        Upserts data rows in this dataset. When "key" is provided and it references an existing data row,
+        an update will be performed. When "key" is not provided, a new data row will be created.
+
+        >>> task = dataset.upsert_data_rows([
+        >>>     # create new data row
+        >>>     {
+        >>>         "row_data": "http://my_site.com/photos/img_01.jpg",
+        >>>         "global_key": "global_key1",
+        >>>         "external_id": "ex_id1",
+        >>>         "attachments": [
+        >>>             {"type": AttachmentType.RAW_TEXT, "name": "att1", "value": "test1"}
+        >>>         ],
+        >>>         "metadata": [
+        >>>             {"name": "tag", "value": "tag value"},
+        >>>         ]
+        >>>     },
+        >>>     # update global key of data row by existing global key
+        >>>     {
+        >>>         "key": GlobalKey("global_key1"),
+        >>>         "global_key": "global_key1_updated"
+        >>>     },
+        >>>     # update data row by ID
+        >>>     {
+        >>>         "key": UniqueId(dr.uid),
+        >>>         "external_id": "ex_id1_updated"
+        >>>     },
+        >>> ])
+        >>> task.wait_till_done()
+        """
+        if len(items) > MAX_DATAROW_PER_API_OPERATION:
+            raise MalformedQueryException(
+                f"Cannot upsert more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
+            )
+
+        specs = self._convert_items_to_upsert_format(items)
+        chunks = [
+            specs[i:i + self.__upsert_chunk_size]
+            for i in range(0, len(specs), self.__upsert_chunk_size)
+        ]
+
+        def _upload_chunk(_chunk):
+            return self._create_descriptor_file(_chunk, is_upsert=True)
+
+        with ThreadPoolExecutor(file_upload_thread_count) as executor:
+            futures = [
+                executor.submit(_upload_chunk, chunk) for chunk in chunks
+            ]
+            chunk_uris = [future.result() for future in as_completed(futures)]
+
+        manifest = {
+            "source": "SDK",
+            "item_count": len(specs),
+            "chunk_uris": chunk_uris
+        }
+        data = json.dumps(manifest).encode("utf-8")
+        manifest_uri = self.client.upload_data(data,
+                                               content_type="application/json",
+                                               filename="manifest.json")
+
+        query_str = """
+            mutation UpsertDataRowsPyApi($manifestUri: String!) {
+                upsertDataRows(data: { manifestUri: $manifestUri }) {
+                    id createdAt updatedAt name status completionPercentage result errors type metadata
+                }
+            }
+            """
+
+        res = self.client.execute(query_str, {"manifestUri": manifest_uri})
+        res = res["upsertDataRows"]
+        task = Task(self.client, res)
+        task._user = self.client.get_user()
+        return task
+
+    def _convert_items_to_upsert_format(self, _items):
+        _upsert_items: List[DataRowUpsertItem] = []
+        for item in _items:
+            # enforce current dataset's id for all specs
+            item['dataset_id'] = self.uid
+            key = item.pop('key', None)
+            if not key:
+                key = {'type': 'AUTO', 'value': ''}
+            elif isinstance(key, UniqueId):
+                key = {'type': 'ID', 'value': key.key}
+            elif isinstance(key, GlobalKey):
+                key = {'type': 'GKEY', 'value': key.key}
+            else:
+                raise ValueError(
+                    f"Key must be an instance of UniqueId or GlobalKey, got: {type(key).__name__}"
+                )
+            item = {
+                k: v for k, v in item.items() if v is not None
+            }  # remove None values
+            _upsert_items.append(DataRowUpsertItem(payload=item, id=key))
+        return _upsert_items
diff --git a/labelbox/schema/task.py b/labelbox/schema/task.py
index 8ba91470c..ed2a49a83 100644
--- a/labelbox/schema/task.py
+++ b/labelbox/schema/task.py
@@ -66,7 +66,7 @@ def wait_till_done(self,
         Args:
             timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes.
-            check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds.
+            check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds.
""" if check_frequency < 2.0: raise ValueError( @@ -90,7 +90,7 @@ def wait_till_done(self, def errors(self) -> Optional[Dict[str, Any]]: """ Fetch the error associated with an import task. """ - if self.name == 'JSON Import': + if self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows': if self.status == "FAILED": result = self._fetch_remote_json() return result["error"] @@ -168,7 +168,7 @@ def download_result(remote_json_field: Optional[str], format: str): "Expected the result format to be either `ndjson` or `json`." ) - if self.name == 'JSON Import': + if self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows': format = 'json' elif self.type == 'export-data-rows': format = 'ndjson' diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py index 8160a6f9f..672afe85d 100644 --- a/tests/integration/test_data_rows.py +++ b/tests/integration/test_data_rows.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime import json + from labelbox.schema.media_type import MediaType import pytest diff --git a/tests/integration/test_data_rows_upsert.py b/tests/integration/test_data_rows_upsert.py new file mode 100644 index 000000000..2cc893476 --- /dev/null +++ b/tests/integration/test_data_rows_upsert.py @@ -0,0 +1,265 @@ +import json +import uuid +from unittest.mock import patch + +import pytest + +from labelbox.schema.asset_attachment import AttachmentType +from labelbox.schema.identifiable import UniqueId, GlobalKey + + +class TestDataRowUpsert: + + @pytest.fixture + def all_inclusive_data_row(self, dataset, image_url): + dr = dataset.create_data_row( + row_data=image_url, + external_id="ex1", + global_key=str(uuid.uuid4()), + metadata_fields=[{ + "name": "tag", + "value": "tag_string" + }, { + "name": "split", + "value": "train" + }], + attachments=[ + { + "type": "RAW_TEXT", + "name": "att1", + "value": "test1" + }, + { + "type": + "IMAGE", + "name": + "att2", + "value": + "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + }, + { + "type": + "PDF_URL", + "name": + "att3", + "value": + "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + }, + ]) + return dr + + def test_create_data_row_with_auto_key(self, dataset, image_url): + task = dataset.upsert_data_rows([{'row_data': image_url}]) + task.wait_till_done() + assert len(list(dataset.data_rows())) == 1 + + def test_create_data_row_with_upsert(self, client, dataset, image_url): + gkey = str(uuid.uuid4()) + task = dataset.upsert_data_rows([{ + 'row_data': + image_url, + 'global_key': + gkey, + 'external_id': + "ex1", + 'attachments': [{ + 'type': AttachmentType.RAW_TEXT, + 'name': "att1", + 'value': "test1" + }, { + 'type': + AttachmentType.IMAGE, + 'name': + "att2", + 'value': + "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + }, { + 'type': + AttachmentType.PDF_URL, + 'name': + "att3", + 'value': + "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + }], + 'metadata': [{ + 'name': "tag", + 'value': "updated tag" + }, { + 'name': "split", + 'value': "train" + }] + }]) + task.wait_till_done() + assert task.status == "COMPLETE" + dr = client.get_data_row_by_global_key(gkey) + + assert dr is not None + assert dr.row_data == image_url + assert dr.global_key == gkey + assert dr.external_id == "ex1" + + attachments = list(dr.attachments()) + assert len(attachments) == 3 + assert attachments[0].attachment_name == "att1" + assert 
attachments[0].attachment_type == AttachmentType.RAW_TEXT + assert attachments[0].attachment_value == "test1" + + assert attachments[1].attachment_name == "att2" + assert attachments[1].attachment_type == AttachmentType.IMAGE + assert attachments[ + 1].attachment_value == "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + + assert attachments[2].attachment_name == "att3" + assert attachments[2].attachment_type == AttachmentType.PDF_URL + assert attachments[ + 2].attachment_value == "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + + assert len(dr.metadata_fields) == 2 + assert dr.metadata_fields[0]['name'] == "tag" + assert dr.metadata_fields[0]['value'] == "updated tag" + assert dr.metadata_fields[1]['name'] == "split" + assert dr.metadata_fields[1]['value'] == "train" + + def test_update_data_row_fields_with_upsert(self, client, dataset, + image_url): + gkey = str(uuid.uuid4()) + dr = dataset.create_data_row(row_data=image_url, + external_id="ex1", + global_key=gkey) + task = dataset.upsert_data_rows([{ + 'key': UniqueId(dr.uid), + 'external_id': "ex1_updated", + 'global_key': f"{gkey}_updated" + }]) + task.wait_till_done() + assert task.status == "COMPLETE" + dr = client.get_data_row(dr.uid) + assert dr is not None + assert dr.external_id == "ex1_updated" + assert dr.global_key == f"{gkey}_updated" + + def test_update_data_row_fields_with_upsert_by_global_key( + self, client, dataset, image_url): + gkey = str(uuid.uuid4()) + dr = dataset.create_data_row(row_data=image_url, + external_id="ex1", + global_key=gkey) + task = dataset.upsert_data_rows([{ + 'key': GlobalKey(dr.global_key), + 'external_id': "ex1_updated", + 'global_key': f"{gkey}_updated" + }]) + task.wait_till_done() + assert task.status == "COMPLETE" + dr = client.get_data_row(dr.uid) + assert dr is not None + assert dr.external_id == "ex1_updated" + assert dr.global_key == f"{gkey}_updated" + + def test_update_attachments_with_upsert(self, client, + all_inclusive_data_row, dataset): + dr = all_inclusive_data_row + task = dataset.upsert_data_rows([{ + 'key': + UniqueId(dr.uid), + 'row_data': + dr.row_data, + 'attachments': [{ + 'type': AttachmentType.RAW_TEXT, + 'name': "att1", + 'value': "test" + }] + }]) + task.wait_till_done() + assert task.status == "COMPLETE" + dr = client.get_data_row(dr.uid) + assert dr is not None + attachments = list(dr.attachments()) + assert len(attachments) == 1 + assert attachments[0].attachment_name == "att1" + + def test_update_metadata_with_upsert(self, client, all_inclusive_data_row, + dataset): + dr = all_inclusive_data_row + task = dataset.upsert_data_rows([{ + 'key': + GlobalKey(dr.global_key), + 'row_data': + dr.row_data, + 'metadata': [{ + 'name': "tag", + 'value': "updated tag" + }, { + 'name': "split", + 'value': "train" + }] + }]) + task.wait_till_done() + assert task.status == "COMPLETE" + dr = client.get_data_row(dr.uid) + assert dr is not None + assert len(dr.metadata_fields) == 2 + assert dr.metadata_fields[0]['name'] == "tag" + assert dr.metadata_fields[0]['value'] == "updated tag" + assert dr.metadata_fields[1]['name'] == "split" + assert dr.metadata_fields[1]['value'] == "train" + + def test_multiple_chunks(self, client, dataset, image_url): + mocked_chunk_size = 3 + with patch('labelbox.client.Client.upload_data', + wraps=client.upload_data) as spy_some_function: + with patch( + 'labelbox.schema.dataset.Dataset._Dataset__upsert_chunk_size', + new=mocked_chunk_size): + task = 
dataset.upsert_data_rows([{ + 'row_data': image_url + } for i in range(10)]) + task.wait_till_done() + assert len(list(dataset.data_rows())) == 10 + assert spy_some_function.call_count == 5 # 4 chunks + manifest + + first_call_args, _ = spy_some_function.call_args_list[0] + first_chunk_content = first_call_args[0] + data = json.loads(first_chunk_content) + assert len(data) == mocked_chunk_size + + last_call_args, _ = spy_some_function.call_args_list[-1] + manifest_content = last_call_args[0].decode('utf-8') + data = json.loads(manifest_content) + assert data['source'] == "SDK" + assert data['item_count'] == 10 + assert len(data['chunk_uris']) == 4 + + def test_upsert_embedded_row_data(self, dataset): + pdf_url = "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483.pdf" + task = dataset.upsert_data_rows([{ + 'row_data': { + "pdf_url": + pdf_url, + "text_layer_url": + "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483-lb-textlayer.json" + }, + 'media_type': "PDF" + }]) + task.wait_till_done() + data_rows = list(dataset.data_rows()) + assert len(data_rows) == 1 + assert data_rows[0].row_data == pdf_url + + def test_upsert_duplicate_global_key_error(self, dataset, image_url): + gkey = str(uuid.uuid4()) + task = dataset.upsert_data_rows([ + { + 'row_data': image_url, + 'global_key': gkey + }, + { + 'row_data': image_url, + 'global_key': gkey + }, + ]) + task.wait_till_done() + assert task.status == "COMPLETE" + assert task.errors is not None + assert len(task.errors) == 1 # one data row was created, one failed + assert f"Duplicate global key: '{gkey}'" in task.errors[0]['message']
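
For reference, a minimal usage sketch of the `Dataset.upsert_data_rows` API introduced by this diff, based only on its docstring and the integration tests above. The API key, dataset id, data row id, asset URL, and global keys are placeholders; everything else uses names that exist in the diff (`UniqueId`, `GlobalKey`, `AttachmentType`).

# Usage sketch for Dataset.upsert_data_rows (all ids, URLs, and keys below are placeholders).
from labelbox import Client
from labelbox.schema.identifiable import UniqueId, GlobalKey
from labelbox.schema.asset_attachment import AttachmentType

client = Client(api_key="<API_KEY>")          # placeholder API key
dataset = client.get_dataset("<DATASET_ID>")  # placeholder dataset id

task = dataset.upsert_data_rows([
    # No "key": the spec gets an AUTO key and a new data row is created.
    {
        "row_data": "https://example.com/image.jpg",  # placeholder asset URL
        "global_key": "my-global-key-1",
        "attachments": [{
            "type": AttachmentType.RAW_TEXT,
            "name": "att1",
            "value": "test1"
        }],
    },
    # "key" as UniqueId: updates the data row with that id.
    {
        "key": UniqueId("<DATA_ROW_ID>"),  # placeholder data row id
        "external_id": "ex_id_updated",
    },
    # "key" as GlobalKey: updates the data row registered under that global key.
    {
        "key": GlobalKey("my-global-key-1"),
        "global_key": "my-global-key-1-updated",
    },
])
task.wait_till_done()
assert task.status == "COMPLETE"
print(task.errors)  # per-row failures (e.g. duplicate global keys) surface here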
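
The chunk/manifest bookkeeping inside `upsert_data_rows` can be sanity-checked in isolation. The sketch below is plain Python with no SDK calls; it mirrors the slicing expression and manifest layout from the method above and reproduces the numbers asserted in `test_multiple_chunks`: 10 items with a chunk size of 3 yield 4 chunk uploads plus 1 manifest upload. The item payloads and chunk URIs are placeholders.

# Stand-alone illustration of the chunking logic used by upsert_data_rows.
specs = [{"row_data": f"https://example.com/{i}.jpg"} for i in range(10)]  # placeholder items
chunk_size = 3  # the SDK default is Dataset.__upsert_chunk_size = 10_000

# Same slicing expression as in Dataset.upsert_data_rows.
chunks = [specs[i:i + chunk_size] for i in range(0, len(specs), chunk_size)]
assert len(chunks) == 4  # chunk sizes: 3 + 3 + 3 + 1

# One descriptor file is uploaded per chunk; the manifest references them all.
manifest = {
    "source": "SDK",
    "item_count": len(specs),
    "chunk_uris": [f"uri-{n}" for n, _ in enumerate(chunks)],  # placeholder URIs
}
assert manifest["item_count"] == 10
assert len(manifest["chunk_uris"]) == 4
# Total uploads observed by the test: 4 chunks + 1 manifest = 5 calls to Client.upload_data.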