[PLT-344] Add dataset.upsert_data_rows method #1460
Changes from all commits
```diff
@@ -1,4 +1,5 @@
-from typing import Dict, Generator, List, Optional, Union, Any
+from datetime import datetime
+from typing import Dict, Generator, List, Optional, Any, Final
 import os
 import json
 import logging
```
```diff
@@ -14,18 +15,19 @@
 from io import StringIO
 import requests

 from labelbox import pagination
 from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
 from labelbox.orm.comparison import Comparison
 from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
 from labelbox.pagination import PaginatedCollection
 from labelbox.pydantic_compat import BaseModel
 from labelbox.schema.data_row import DataRow
 from labelbox.schema.export_filters import DatasetExportFilters, build_filters
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
 from labelbox.schema.export_task import ExportTask
 from labelbox.schema.identifiable import UniqueId, GlobalKey
 from labelbox.schema.task import Task
 from labelbox.schema.user import User
```
```diff
@@ -34,6 +36,11 @@
 MAX_DATAROW_PER_API_OPERATION = 150_000


+class DataRowUpsertItem(BaseModel):
+    id: dict
+    payload: dict
+
+
 class Dataset(DbObject, Updateable, Deletable):
     """ A Dataset is a collection of DataRows.

```
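For reference, a minimal sketch of the shape each upsert item takes once wrapped in this model. It uses `pydantic.BaseModel` directly in place of `labelbox.pydantic_compat`, and the key values are invented for illustration; the `'AUTO'`/`'ID'`/`'GKEY'` types come from `_convert_items_to_upsert_format` further down in this diff.

```python
# Sketch only: mirrors the DataRowUpsertItem wrapper added in the diff above.
from pydantic import BaseModel


class DataRowUpsertItem(BaseModel):
    id: dict       # how the target data row is addressed ('AUTO', 'ID' or 'GKEY')
    payload: dict  # the fields to create or update


item = DataRowUpsertItem(
    id={"type": "GKEY", "value": "global_key1"},
    payload={"external_id": "ex_id1_updated"},
)
print(item.id["type"], item.payload)  # GKEY {'external_id': 'ex_id1_updated'}
```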
```diff
@@ -47,6 +54,8 @@ class Dataset(DbObject, Updateable, Deletable):
         created_by (Relationship): `ToOne` relationship to User
         organization (Relationship): `ToOne` relationship to Organization
     """
+    __upsert_chunk_size: Final = 10_000
+
     name = Field.String("name")
     description = Field.String("description")
     updated_at = Field.DateTime("updated_at")
```

Review comments on `__upsert_chunk_size: Final = 10_000`:

Contributor: how was this number chosen?

Contributor: nit: In Python, constants are typically written in all capital letters.
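As a rough sense of scale for the chunk-size question above, a small arithmetic sketch using the two constants that appear in this diff (the all-caps spelling follows the review nit and is not what the PR uses):

```python
# With the per-call cap of 150_000 rows and a chunk size of 10_000,
# a maximal upsert call produces at most 15 descriptor files.
MAX_DATAROW_PER_API_OPERATION = 150_000
UPSERT_CHUNK_SIZE = 10_000  # named in caps here only for illustration

max_chunks = -(-MAX_DATAROW_PER_API_OPERATION // UPSERT_CHUNK_SIZE)  # ceiling division
print(max_chunks)  # 15
```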
```diff
@@ -64,16 +73,16 @@ def data_rows(
         from_cursor: Optional[str] = None,
         where: Optional[Comparison] = None,
     ) -> PaginatedCollection:
         """
         Custom method to paginate data_rows via cursor.

         Args:
             from_cursor (str): Cursor (data row id) to start from, if none, will start from the beginning
             where (dict(str,str)): Filter to apply to data rows. Where value is a data row column name and key is the value to filter on.
                 example: {'external_id': 'my_external_id'} to get a data row with external_id = 'my_external_id'

         NOTE:
             Order of retrieval is newest data row first.
             Deleted data rows are not retrieved.
             Failed data rows are not retrieved.
```
```diff
@@ -293,7 +302,10 @@ def create_data_rows(self, items) -> "Task":
         task._user = user
         return task

-    def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
+    def _create_descriptor_file(self,
+                                items,
+                                max_attachments_per_data_row=None,
+                                is_upsert=False):
         """
         This function is shared by both `Dataset.create_data_rows` and `Dataset.create_data_rows_sync`
         to prepare the input file. The user defined input is validated, processed, and json stringified.
```

Review comments on the `_create_descriptor_file` signature change:

Contributor: I understand you wanted to keep it DRY, but the code is arguably more complex as a result.

Contributor (author): I will address this in the follow-up work.
```diff
@@ -346,6 +358,9 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
         AssetAttachment = Entity.AssetAttachment

         def upload_if_necessary(item):
+            if is_upsert and 'row_data' not in item:
+                # When upserting, row_data is not required
+                return item
             row_data = item['row_data']
             if isinstance(row_data, str) and os.path.exists(row_data):
                 item_url = self.client.upload_file(row_data)
```
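To make the new early return concrete, here is a standalone sketch of the helper's behavior. `self.client.upload_file` is replaced by a stub so the snippet runs outside the SDK; the stub URL is invented.

```python
import os


def upload_file_stub(path):
    # Stand-in for self.client.upload_file(), which uploads a local file
    # and returns a hosted URL.
    return f"https://storage.example.com/{os.path.basename(path)}"


def upload_if_necessary(item, is_upsert=False):
    # When upserting, an item may carry only updates (e.g. a new external_id),
    # so row_data is optional and the item passes through untouched.
    if is_upsert and 'row_data' not in item:
        return item
    row_data = item['row_data']
    if isinstance(row_data, str) and os.path.exists(row_data):
        item['row_data'] = upload_file_stub(row_data)
    return item


print(upload_if_necessary({"external_id": "ex_1"}, is_upsert=True))
# {'external_id': 'ex_1'}  -- no KeyError even though row_data is missing
```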
```diff
@@ -425,17 +440,17 @@ def format_row(item):
             return item

         def validate_keys(item):
-            if 'row_data' not in item:
+            if not is_upsert and 'row_data' not in item:
                 raise InvalidQueryError(
                     "`row_data` missing when creating DataRow.")

             if isinstance(item.get('row_data'),
                           str) and item.get('row_data').startswith("s3:/"):
                 raise InvalidQueryError(
                     "row_data: s3 assets must start with 'https'.")
-            invalid_keys = set(item) - {
-                *{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
-            }
+            allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}
+            invalid_keys = set(item) - {f.name for f in DataRow.fields()
+                                       } - allowed_extra_fields
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)
             return item
```
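A small sketch of the key validation with the new `allowed_extra_fields` set. The `DataRow` field names below are a hypothetical subset rather than the real `DataRow.fields()` output; the point is only the set arithmetic that now permits `dataset_id`.

```python
# Hypothetical subset of DataRow field names, for illustration only.
DATA_ROW_FIELD_NAMES = {'row_data', 'external_id', 'global_key', 'metadata_fields'}
allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}


def invalid_keys_for(item: dict) -> set:
    # Anything that is neither a DataRow field nor an explicitly allowed extra
    # field is rejected (InvalidAttributeError in the real code).
    return set(item) - DATA_ROW_FIELD_NAMES - allowed_extra_fields


print(invalid_keys_for({'row_data': 'https://...', 'dataset_id': 'abc'}))  # set()
print(invalid_keys_for({'row_data': 'https://...', 'foo': 1}))             # {'foo'}
```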
```diff
@@ -460,7 +475,12 @@ def formatLegacyConversationalData(item):
             item["row_data"] = one_conversation
             return item

-        def convert_item(item):
+        def convert_item(data_row_item):
+            if isinstance(data_row_item, DataRowUpsertItem):
+                item = data_row_item.payload
+            else:
+                item = data_row_item
+
             if "tileLayerUrl" in item:
                 validate_attachments(item)
                 return item
```
```diff
@@ -478,7 +498,11 @@ def convert_item(item):
             parse_metadata_fields(item)
             # Upload any local file paths
             item = upload_if_necessary(item)
-            return item
+
+            if isinstance(data_row_item, DataRowUpsertItem):
+                return {'id': data_row_item.id, 'payload': item}
+            else:
+                return item

         if not isinstance(items, Iterable):
             raise ValueError(
```
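To illustrate the two return shapes of `convert_item` after this change, a reduced sketch follows. The real function also validates, formats, and uploads the item; that processing is skipped here, and `pydantic.BaseModel` stands in for `labelbox.pydantic_compat`.

```python
from pydantic import BaseModel  # stand-in for labelbox.pydantic_compat


class DataRowUpsertItem(BaseModel):
    id: dict
    payload: dict


def convert_item_shape(data_row_item):
    # Unwrap upsert items, pretend the payload was processed, then re-wrap.
    if isinstance(data_row_item, DataRowUpsertItem):
        item = data_row_item.payload
    else:
        item = data_row_item
    # ... validation / formatting / uploads happen here in the real code ...
    if isinstance(data_row_item, DataRowUpsertItem):
        return {'id': data_row_item.id, 'payload': item}
    return item


plain = convert_item_shape({'row_data': 'https://example.com/img.jpg'})
upsert = convert_item_shape(
    DataRowUpsertItem(id={'type': 'AUTO', 'value': ''},
                      payload={'row_data': 'https://example.com/img.jpg'}))
print(plain)   # {'row_data': 'https://example.com/img.jpg'}
print(upsert)  # {'id': {'type': 'AUTO', 'value': ''}, 'payload': {...}}
```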
```diff
@@ -638,13 +662,13 @@ def export_v2(
     ) -> Task:
         """
         Creates a dataset export task with the given params and returns the task.

         >>> dataset = client.get_dataset(DATASET_ID)
         >>> task = dataset.export_v2(
         >>>     filters={
         >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
         >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
         >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
         >>>     },
         >>>     params={
         >>>         "performance_details": False,
```
```diff
@@ -749,3 +773,100 @@ def _export(
         res = res[mutation_name]
         task_id = res["taskId"]
         return Task.get_task(self.client, task_id)
+
+    def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task":
+        """
+        Upserts data rows in this dataset. When "key" is provided, and it references an existing data row,
+        an update will be performed. When "key" is not provided, a new data row will be created.
+
+        >>> task = dataset.upsert_data_rows([
+        >>>     # create new data row
+        >>>     {
+        >>>         "row_data": "http://my_site.com/photos/img_01.jpg",
+        >>>         "global_key": "global_key1",
+        >>>         "external_id": "ex_id1",
+        >>>         "attachments": [
+        >>>             {"type": AttachmentType.RAW_TEXT, "name": "att1", "value": "test1"}
+        >>>         ],
+        >>>         "metadata": [
+        >>>             {"name": "tag", "value": "tag value"},
+        >>>         ]
+        >>>     },
+        >>>     # update global key of data row by existing global key
+        >>>     {
+        >>>         "key": GlobalKey("global_key1"),
+        >>>         "global_key": "global_key1_updated"
+        >>>     },
+        >>>     # update data row by ID
+        >>>     {
+        >>>         "key": UniqueId(dr.uid),
+        >>>         "external_id": "ex_id1_updated"
+        >>>     },
+        >>> ])
+        >>> task.wait_till_done()
+        """
+        if len(items) > MAX_DATAROW_PER_API_OPERATION:
+            raise MalformedQueryException(
+                f"Cannot upsert more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
+            )
+
+        specs = self._convert_items_to_upsert_format(items)
+        chunks = [
+            specs[i:i + self.__upsert_chunk_size]
+            for i in range(0, len(specs), self.__upsert_chunk_size)
+        ]
+
+        def _upload_chunk(_chunk):
+            return self._create_descriptor_file(_chunk, is_upsert=True)
+
+        with ThreadPoolExecutor(file_upload_thread_count) as executor:
+            futures = [
+                executor.submit(_upload_chunk, chunk) for chunk in chunks
+            ]
+            chunk_uris = [future.result() for future in as_completed(futures)]
+
+        manifest = {
+            "source": "SDK",
+            "item_count": len(specs),
+            "chunk_uris": chunk_uris
+        }
+        data = json.dumps(manifest).encode("utf-8")
+        manifest_uri = self.client.upload_data(data,
+                                               content_type="application/json",
+                                               filename="manifest.json")
+
+        query_str = """
+            mutation UpsertDataRowsPyApi($manifestUri: String!) {
+                upsertDataRows(data: { manifestUri: $manifestUri }) {
+                    id createdAt updatedAt name status completionPercentage result errors type metadata
+                }
+            }
+            """
+
+        res = self.client.execute(query_str, {"manifestUri": manifest_uri})
+        res = res["upsertDataRows"]
+        task = Task(self.client, res)
+        task._user = self.client.get_user()
+        return task
+
+    def _convert_items_to_upsert_format(self, _items):
+        _upsert_items: List[DataRowUpsertItem] = []
+        for item in _items:
+            # enforce current dataset's id for all specs
+            item['dataset_id'] = self.uid
+            key = item.pop('key', None)
+            if not key:
+                key = {'type': 'AUTO', 'value': ''}
+            elif isinstance(key, UniqueId):
+                key = {'type': 'ID', 'value': key.key}
+            elif isinstance(key, GlobalKey):
+                key = {'type': 'GKEY', 'value': key.key}
+            else:
+                raise ValueError(
+                    f"Key must be an instance of UniqueId or GlobalKey, got: {type(key).__name__}"
+                )
+            item = {
+                k: v for k, v in item.items() if v is not None
+            }  # remove None values
+            _upsert_items.append(DataRowUpsertItem(payload=item, id=key))
+        return _upsert_items
```
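Putting the pieces together, a hedged sketch of roughly what the upsert specs and manifest look like for the three items in the docstring example. The dataset id, data row id, and chunk URI are invented; the real `chunk_uris` come from the descriptor files uploaded by `_create_descriptor_file`.

```python
import json

# Specs as produced by _convert_items_to_upsert_format (ids invented):
specs = [
    {"id": {"type": "AUTO", "value": ""},             # no key -> create new row
     "payload": {"dataset_id": "ds_123",
                 "row_data": "http://my_site.com/photos/img_01.jpg",
                 "global_key": "global_key1"}},
    {"id": {"type": "GKEY", "value": "global_key1"},  # update by global key
     "payload": {"dataset_id": "ds_123", "global_key": "global_key1_updated"}},
    {"id": {"type": "ID", "value": "cl_datarow_id"},  # update by data row id
     "payload": {"dataset_id": "ds_123", "external_id": "ex_id1_updated"}},
]

# Manifest uploaded alongside the chunk files and passed to the mutation:
manifest = {
    "source": "SDK",
    "item_count": len(specs),
    "chunk_uris": ["https://storage.example.com/chunk-0.json"],  # invented URI
}
print(json.dumps(manifest, indent=2))
```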