diff --git a/labelbox/schema/asset_attachment.py b/labelbox/schema/asset_attachment.py index 219ebbab8..9ede2b5e4 100644 --- a/labelbox/schema/asset_attachment.py +++ b/labelbox/schema/asset_attachment.py @@ -49,7 +49,15 @@ def validate_attachment_json(cls, attachment_json: Dict[str, str]) -> None: raise ValueError( f"Must provide a `{required_key}` key for each attachment. Found {attachment_json}." ) - cls.validate_attachment_type(attachment_json['type']) + cls.validate_attachment_value(attachment_json['value']) + cls.validate_attachment_type(attachment_json['type']) + + @classmethod + def validate_attachment_value(cls, attachment_value: str) -> None: + if not isinstance(attachment_value, str) or attachment_value == "": + raise ValueError( + f"Attachment value must be a non-empty string, got: '{attachment_value}'" + ) @classmethod def validate_attachment_type(cls, attachment_type: str) -> None: @@ -72,8 +80,20 @@ def update(self, type: Optional[str] = None, value: Optional[str] = None): """Updates an attachment on the data row.""" + if not name and not type and value is None: + raise ValueError( + "At least one of the following must be provided: name, type, value" + ) + + query_params = {"attachment_id": self.uid} if type: self.validate_attachment_type(type) + query_params["type"] = type + if value is not None: + self.validate_attachment_value(value) + query_params["value"] = value + if name: + query_params["name"] = name query_str = """mutation updateDataRowAttachmentPyApi($attachment_id: ID!, $name: String, $type: AttachmentType, $value: String) { updateDataRowAttachment( @@ -81,13 +101,8 @@ def update(self, data: {name: $name, type: $type, value: $value} ) { id name type value } }""" - res = (self.client.execute( - query_str, { - "attachment_id": self.uid, - "name": name, - "type": type, - "value": value - }))['updateDataRowAttachment'] + res = (self.client.execute(query_str, + query_params))['updateDataRowAttachment'] self.attachment_name = res['name'] self.attachment_value = res['value'] diff --git a/labelbox/schema/data_row.py b/labelbox/schema/data_row.py index 9d8758934..411f78879 100644 --- a/labelbox/schema/data_row.py +++ b/labelbox/schema/data_row.py @@ -137,9 +137,13 @@ def create_attachment(self, Returns: `AssetAttachment` DB object. Raises: - ValueError: asset_type must be one of the supported types. + ValueError: attachment_type must be one of the supported types. + ValueError: attachment_value must be a non-empty string. """ - Entity.AssetAttachment.validate_attachment_type(attachment_type) + Entity.AssetAttachment.validate_attachment_json({ + 'type': attachment_type, + 'value': attachment_value + }) attachment_type_param = "type" attachment_value_param = "value" @@ -220,13 +224,13 @@ def export_v2( task_name (str): name of remote task params (CatalogExportParams): export params - + >>> dataset = client.get_dataset(DATASET_ID) >>> task = DataRow.export_v2( - >>> data_rows=[data_row.uid for data_row in dataset.data_rows.list()], + >>> data_rows=[data_row.uid for data_row in dataset.data_rows.list()], >>> # or a list of DataRow objects: data_rows = data_set.data_rows.list() - >>> # or a list of global_keys=["global_key_1", "global_key_2"], - >>> # Note that exactly one of: data_rows or global_keys parameters can be passed in at a time + >>> # or a list of global_keys=["global_key_1", "global_key_2"], + >>> # Note that exactly one of: data_rows or global_keys parameters can be passed in at a time >>> # and if data rows ids is present, global keys will be ignored >>> params={ >>> "performance_details": False, diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py index b9cc4d118..f4b8c337e 100644 --- a/tests/integration/test_data_rows.py +++ b/tests/integration/test_data_rows.py @@ -747,6 +747,26 @@ def test_create_data_rows_sync_mixed_upload(dataset, image_url): assert len(list(dataset.data_rows())) == n_local + n_urls +def test_create_data_row_attachment(data_row): + att = data_row.create_attachment("IMAGE", "https://example.com/image.jpg", + "name") + assert att.attachment_type == "IMAGE" + assert att.attachment_value == "https://example.com/image.jpg" + assert att.attachment_name == "name" + + +def test_create_data_row_attachment_invalid_type(data_row): + with pytest.raises(ValueError): + data_row.create_attachment("SOME_TYPE", "value", "name") + + +def test_create_data_row_attachment_invalid_value(data_row): + with pytest.raises(ValueError): + data_row.create_attachment("IMAGE", "", "name") + with pytest.raises(ValueError): + data_row.create_attachment("IMAGE", None, "name") + + def test_delete_data_row_attachment(data_row, image_url): attachments = [] @@ -791,6 +811,26 @@ def test_update_data_row_attachment_invalid_type(data_row): attachment.update(name="updated name", type="INVALID", value="value") +def test_update_data_row_attachment_invalid_value(data_row): + attachment: AssetAttachment = data_row.create_attachment( + "RAW_TEXT", "value", "name") + assert attachment is not None + with pytest.raises(ValueError): + attachment.update(name="updated name", type="IMAGE", value="") + + +def test_does_not_update_not_provided_attachment_fields(data_row): + attachment: AssetAttachment = data_row.create_attachment( + "RAW_TEXT", "value", "name") + assert attachment is not None + attachment.update(value=None, name="name") + assert attachment.attachment_value == "value" + attachment.update(name=None, value="value") + assert attachment.attachment_name == "name" + attachment.update(type=None, name="name") + assert attachment.attachment_type == "RAW_TEXT" + + def test_create_data_rows_result(client, dataset, image_url): task = dataset.create_data_rows([ {