diff --git a/labelbox/schema/asset_attachment.py b/labelbox/schema/asset_attachment.py index 12a658ab3..219ebbab8 100644 --- a/labelbox/schema/asset_attachment.py +++ b/labelbox/schema/asset_attachment.py @@ -1,6 +1,6 @@ import warnings from enum import Enum -from typing import Dict +from typing import Dict, Optional from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field @@ -12,6 +12,7 @@ class AssetAttachment(DbObject): Attributes: attachment_type (str): IMAGE, VIDEO, IMAGE_OVERLAY, HTML, RAW_TEXT, TEXT_URL, or PDF_URL. TEXT attachment type is deprecated. attachment_value (str): URL to an external file or a string of text + attachment_name (str): The name of the attachment """ class AttachmentType(Enum): @@ -27,7 +28,6 @@ def __missing__(cls, value: object): VIDEO = "VIDEO" IMAGE = "IMAGE" - # TEXT = "TEXT" # Deprecated IMAGE_OVERLAY = "IMAGE_OVERLAY" HTML = "HTML" RAW_TEXT = "RAW_TEXT" @@ -40,6 +40,7 @@ def __missing__(cls, value: object): attachment_type = Field.String("attachment_type", "type") attachment_value = Field.String("attachment_value", "value") + attachment_name = Field.String("attachment_name", "name") @classmethod def validate_attachment_json(cls, attachment_json: Dict[str, str]) -> None: @@ -55,7 +56,7 @@ def validate_attachment_type(cls, attachment_type: str) -> None: valid_types = set(cls.AttachmentType.__members__) if attachment_type not in valid_types: raise ValueError( - f"meta_type must be one of {valid_types}. Found {attachment_type}" + f"attachment_type must be one of {valid_types}. Found {attachment_type}" ) def delete(self) -> None: @@ -65,3 +66,29 @@ def delete(self) -> None: id} }""" self.client.execute(query_str, {"attachment_id": self.uid}) + + def update(self, + name: Optional[str] = None, + type: Optional[str] = None, + value: Optional[str] = None): + """Updates an attachment on the data row.""" + if type: + self.validate_attachment_type(type) + + query_str = """mutation updateDataRowAttachmentPyApi($attachment_id: ID!, $name: String, $type: AttachmentType, $value: String) { + updateDataRowAttachment( + where: {id: $attachment_id}, + data: {name: $name, type: $type, value: $value} + ) { id name type value } + }""" + res = (self.client.execute( + query_str, { + "attachment_id": self.uid, + "name": name, + "type": type, + "value": value + }))['updateDataRowAttachment'] + + self.attachment_name = res['name'] + self.attachment_value = res['value'] + self.attachment_type = res['type'] diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py index deea19e6b..b9cc4d118 100644 --- a/tests/integration/test_data_rows.py +++ b/tests/integration/test_data_rows.py @@ -7,7 +7,7 @@ import pytest import requests -from labelbox import DataRow +from labelbox import DataRow, AssetAttachment from labelbox.exceptions import MalformedQueryException from labelbox.schema.task import Task from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadataKind @@ -773,6 +773,24 @@ def test_delete_data_row_attachment(data_row, image_url): assert len(list(data_row.attachments())) == 0 +def test_update_data_row_attachment(data_row, image_url): + attachment: AssetAttachment = data_row.create_attachment( + "RAW_TEXT", "value", "name") + assert attachment is not None + attachment.update(name="updated name", type="IMAGE", value=image_url) + assert attachment.attachment_name == "updated name" + assert attachment.attachment_type == "IMAGE" + assert attachment.attachment_value == image_url + + +def test_update_data_row_attachment_invalid_type(data_row): + attachment: AssetAttachment = data_row.create_attachment( + "RAW_TEXT", "value", "name") + assert attachment is not None + with pytest.raises(ValueError): + attachment.update(name="updated name", type="INVALID", value="value") + + def test_create_data_rows_result(client, dataset, image_url): task = dataset.create_data_rows([ {