diff --git a/labelbox/schema/data_row.py b/labelbox/schema/data_row.py
index 4c7bb8287..0ec7a4e6e 100644
--- a/labelbox/schema/data_row.py
+++ b/labelbox/schema/data_row.py
@@ -1,5 +1,6 @@
 import logging
 from typing import TYPE_CHECKING
+import json
 
 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
@@ -64,6 +65,14 @@ def __init__(self, *args, **kwargs):
         self.attachments.supports_filtering = False
         self.attachments.supports_sorting = False
 
+    def update(self, **kwargs):
+        # Convert row data to string if it is an object
+        # All other updates pass through
+        row_data = kwargs.get("row_data")
+        if isinstance(row_data, dict):
+            kwargs['row_data'] = json.dumps(kwargs['row_data'])
+        super().update(**kwargs)
+
     @staticmethod
     def bulk_delete(data_rows) -> None:
         """ Deletes all the given DataRows.
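Reviewer note: a minimal usage sketch of the new DataRow.update behavior, not part of the diff. The API key and data row UID below are placeholders. A dict passed as row_data is JSON-serialized before the update is sent; string payloads pass through unchanged.

# Usage sketch; API_KEY and the data row UID are placeholders.
from labelbox import Client

client = Client(api_key="API_KEY")
data_row = client.get_data_row("<DATA_ROW_UID>")

# A dict row_data is JSON-serialized by update() before being sent,
# matching what Dataset.create_data_row now does at creation time.
data_row.update(row_data={
    "pdfUrl": "https://example.com/doc.pdf",
    "textLayerUrl": "https://example.com/text-layer.json"
})

# Plain string payloads (URLs, raw text) still pass through unchanged.
data_row.update(row_data="https://example.com/image.jpg")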
diff --git a/labelbox/schema/dataset.py b/labelbox/schema/dataset.py
index a2deb62ab..bdfe02e35 100644
--- a/labelbox/schema/dataset.py
+++ b/labelbox/schema/dataset.py
@@ -15,6 +15,7 @@
 from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
 from labelbox.orm.db_object import DbObject, Updateable, Deletable
 from labelbox.orm.model import Entity, Field, Relationship
+from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
 
 if TYPE_CHECKING:
@@ -95,18 +96,46 @@ def convert_field_keys(items):
             raise InvalidQueryError(
                 "DataRow.row_data missing when creating DataRow.")
 
-        # If row data is a local file path, upload it to server.
         row_data = args[DataRow.row_data.name]
-        if os.path.exists(row_data):
+        if not isinstance(row_data, str):
+            # If the row data is an object, upload as a string
+            args[DataRow.row_data.name] = json.dumps(row_data)
+        elif os.path.exists(row_data):
+            # If row data is a local file path, upload it to server.
             args[DataRow.row_data.name] = self.client.upload_file(row_data)
-        args[DataRow.dataset.name] = self
 
         # Parse metadata fields, if they are provided
         if DataRow.metadata_fields.name in args:
            mdo = self.client.get_data_row_metadata_ontology()
            args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
                args[DataRow.metadata_fields.name])
-        return self.client._create(DataRow, args)
+
+        query_str = """mutation CreateDataRowPyApi(
+            $row_data: String!,
+            $metadata_fields: [DataRowCustomMetadataUpsertInput!],
+            $attachments: [DataRowAttachmentInput!],
+            $media_type: MediaType,
+            $external_id: String,
+            $global_key: String,
+            $dataset: ID!
+        ){
+            createDataRow(
+                data:
+                  {
+                    rowData: $row_data
+                    mediaType: $media_type
+                    metadataFields: $metadata_fields
+                    externalId: $external_id
+                    globalKey: $global_key
+                    attachments: $attachments
+                    dataset: {connect: {id: $dataset}}
+                  }
+            )
+            {%s}
+        }
+        """ % query.results_query_part(Entity.DataRow)
+        res = self.client.execute(query_str, {**args, 'dataset': self.uid})
+        return DataRow(self.client, res['createDataRow'])
 
     def create_data_rows_sync(self, items) -> None:
         """ Synchronously bulk upload data rows.
@@ -229,8 +258,8 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
            >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
            >>>     {DataRow.row_data:"/path/to/file1.jpg"},
            >>>     "path/to/file2.jpg",
-           >>>     {"tileLayerUrl" : "http://", ...}
-           >>>     {"conversationalData" : [...], ...}
+           >>>     {DataRow.row_data: {"tileLayerUrl" : "http://", ...}}
+           >>>     {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
            >>>     ])
 
         For an example showing how to upload tiled data_rows see the following notebook:
@@ -258,7 +287,7 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
         def upload_if_necessary(item):
             row_data = item['row_data']
-            if os.path.exists(row_data):
+            if isinstance(row_data, str) and os.path.exists(row_data):
                 item_url = self.client.upload_file(row_data)
                 item['row_data'] = item_url
 
             if 'external_id' not in item:
@@ -341,40 +370,39 @@ def validate_keys(item):
                     "`row_data` missing when creating DataRow.")
 
             invalid_keys = set(item) - {
-                *{f.name for f in DataRow.fields()}, 'attachments'
+                *{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
             }
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)
             return item
 
+        def formatLegacyConversationalData(item):
+            messages = item.pop("conversationalData")
+            version = item.pop("version", 1)
+            type = item.pop("type", "application/vnd.labelbox.conversational")
+            if "externalId" in item:
+                external_id = item.pop("externalId")
+                item["external_id"] = external_id
+            if "globalKey" in item:
+                global_key = item.pop("globalKey")
+                item["globalKey"] = global_key
+            validate_conversational_data(messages)
+            one_conversation = \
+                {
+                    "type": type,
+                    "version": version,
+                    "messages": messages
+                }
+            item["row_data"] = one_conversation
+            return item
+
         def convert_item(item):
-            # Don't make any changes to tms data
             if "tileLayerUrl" in item:
                 validate_attachments(item)
                 return item
 
             if "conversationalData" in item:
-                messages = item.pop("conversationalData")
-                version = item.pop("version")
-                type = item.pop("type")
-                if "externalId" in item:
-                    external_id = item.pop("externalId")
-                    item["external_id"] = external_id
-                if "globalKey" in item:
-                    global_key = item.pop("globalKey")
-                    item["globalKey"] = global_key
-                validate_conversational_data(messages)
-                one_conversation = \
-                    {
-                        "type": type,
-                        "version": version,
-                        "messages": messages
-                    }
-                conversationUrl = self.client.upload_data(
-                    json.dumps(one_conversation),
-                    content_type="application/json",
-                    filename="conversational_data.json")
-                item["row_data"] = conversationUrl
+                formatLegacyConversationalData(item)
 
             # Convert all payload variations into the same dict format
             item = format_row(item)
@@ -386,11 +414,7 @@ def convert_item(item):
             parse_metadata_fields(item)
             # Upload any local file paths
             item = upload_if_necessary(item)
-
-            return {
-                "data" if key == "row_data" else utils.camel_case(key): value
-                for key, value in item.items()
-            }
+            return item
 
         if not isinstance(items, Iterable):
             raise ValueError(
diff --git a/labelbox/schema/media_type.py b/labelbox/schema/media_type.py
index c4e139a67..aaddb83be 100644
--- a/labelbox/schema/media_type.py
+++ b/labelbox/schema/media_type.py
@@ -21,9 +21,9 @@ class MediaType(Enum):
 
     @classmethod
     def _missing_(cls, name):
-        """Handle missing null data types for projects 
+        """Handle missing null data types for projects
             created without setting allowedMediaType
-        Handle upper case names for compatibility with 
+        Handle upper case names for compatibility with
             the GraphQL"""
 
         if name is None:
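Reviewer note: a sketch of creating data rows against the new createDataRow mutation, not part of the diff. The API key, dataset UID, and URLs are placeholders. row_data may now be a dict; media_type is optional and, per the tests below, validated server-side against the contents of row_data.

# Creation sketch under the new mutation; identifiers are placeholders.
from labelbox import Client

client = Client(api_key="API_KEY")
dataset = client.get_dataset("<DATASET_UID>")

# row_data may now be a dict; it is JSON-serialized before upload.
data_row = dataset.create_data_row(
    row_data={
        "pdfUrl": "https://example.com/doc.pdf",
        "textLayerUrl": "https://example.com/text-layer.json"
    },
    media_type="PDF")

# Attachments can be created in the same call, since the mutation
# connects them directly instead of going through client._create.
dataset.create_data_row(
    row_data="https://example.com/image.jpg",
    attachments=[{"type": "TEXT", "value": "an attachment"}])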
diff --git a/tests/integration/test_data_row_media_attributes.py b/tests/integration/test_data_row_media_attributes.py
index d2e1c10b0..e2a594627 100644
--- a/tests/integration/test_data_row_media_attributes.py
+++ b/tests/integration/test_data_row_media_attributes.py
@@ -7,4 +7,4 @@ def test_export_empty_media_attributes(configured_project_with_label):
     sleep(10)
     labels = project.label_generator()
     label = next(labels)
-    assert label.data.media_attributes == {}
\ No newline at end of file
+    assert label.data.media_attributes == {}
diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py
index 9f08df401..d885c4391 100644
--- a/tests/integration/test_data_rows.py
+++ b/tests/integration/test_data_rows.py
@@ -1,6 +1,7 @@
 from tempfile import NamedTemporaryFile
 import uuid
 from datetime import datetime
+import json
 
 import pytest
 import requests
@@ -28,6 +29,56 @@ def mdo(client):
     yield mdo
 
 
+@pytest.fixture
+def conversational_content():
+    return {
+        'row_data': {
+            "messages": [{
+                "messageId": "message-0",
+                "timestampUsec": 1530718491,
+                "content": "I love iphone! i just bought new iphone! 🥰 📲",
+                "user": {
+                    "userId": "Bot 002",
+                    "name": "Bot"
+                },
+                "align": "left",
+                "canLabel": False
+            }],
+            "version": 1,
+            "type": "application/vnd.labelbox.conversational"
+        }
+    }
+
+
+@pytest.fixture
+def tile_content():
+    return {
+        "row_data": {
+            "tileLayerUrl":
+                "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png",
+            "bounds": [[19.405662413477728, -99.21052827588443],
+                       [19.400498983095076, -99.20534818927473]],
+            "minZoom":
+                12,
+            "maxZoom":
+                20,
+            "epsg":
+                "EPSG4326",
+            "alternativeLayers": [{
+                "tileLayerUrl":
+                    "https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v11/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw",
+                "name":
+                    "Satellite"
+            }, {
+                "tileLayerUrl":
+                    "https://api.mapbox.com/styles/v1/mapbox/navigation-guidance-night-v4/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw",
+                "name":
+                    "Guidance"
+            }]
+        }
+    }
+
+
 def make_metadata_fields():
     embeddings = [0.0] * 128
     msg = "A message"
@@ -408,6 +459,18 @@ def test_data_row_update(dataset, rand_gen, image_url):
     data_row.update(external_id=external_id_2)
     assert data_row.external_id == external_id_2
 
+    in_line_content = "123"
+    data_row.update(row_data=in_line_content)
+    assert requests.get(data_row.row_data).text == in_line_content
+
+    data_row.update(row_data=image_url)
+    assert data_row.row_data == image_url
+
+    # tileLayerUrl becomes a media attribute
+    pdf_url = "http://somepdfurl"
+    data_row.update(row_data={'pdfUrl': pdf_url, "tileLayerUrl": "123"})
+    assert data_row.row_data == pdf_url
+
 
 def test_data_row_filtering_sorting(dataset, image_url):
     task = dataset.create_data_rows([
@@ -696,3 +759,74 @@ def test_data_row_rulk_creation_sync_with_same_global_keys(
 
     assert len(list(dataset.data_rows())) == 1
     assert list(dataset.data_rows())[0].global_key == global_key_1
+
+
+def test_create_conversational_text(dataset, conversational_content):
+    examples = [
+        {
+            **conversational_content, 'media_type': 'CONVERSATIONAL'
+        },
+        conversational_content,
+        {
+            "conversationalData": conversational_content['row_data']['messages']
+        }  # Old way, kept to check backwards compatibility
+    ]
+    dataset.create_data_rows_sync(examples)
+    data_rows = list(dataset.data_rows())
+    assert len(data_rows) == len(examples)
+    for data_row in data_rows:
+        assert requests.get(
+            data_row.row_data).json() == conversational_content['row_data']
+
+
+def test_invalid_media_type(dataset, conversational_content):
+    for error_message, invalid_media_type in [[
+            "Found invalid contents for media type: 'IMAGE'", 'IMAGE'
+    ], ["Found invalid media type: 'totallyinvalid'", 'totallyinvalid']]:
+        # TODO: What error kind should this be? It looks like we raise a
+        # malformed-query error for global keys but InvalidQueryError for invalid file contents.
+        with pytest.raises(labelbox.exceptions.InvalidQueryError):
+            dataset.create_data_rows_sync([{
+                **conversational_content, 'media_type': invalid_media_type
+            }])
+
+        task = dataset.create_data_rows([{
+            **conversational_content, 'media_type': invalid_media_type
+        }])
+        task.wait_till_done()
+        assert task.errors == {'message': error_message}
+
+
+def test_create_tiled_layer(dataset, tile_content):
+    examples = [
+        {
+            **tile_content, 'media_type': 'TMS_SIMPLE'
+        },
+        tile_content,
+        tile_content['row_data']  # Old way, kept to check backwards compatibility
+    ]
+    dataset.create_data_rows_sync(examples)
+    data_rows = list(dataset.data_rows())
+    assert len(data_rows) == len(examples)
+    for data_row in data_rows:
+        assert json.loads(data_row.row_data) == tile_content['row_data']
+
+
+def test_create_data_row_with_attachments(dataset):
+    attachment_value = 'attachment value'
+    dr = dataset.create_data_row(row_data="123",
+                                 attachments=[{
+                                     'type': 'TEXT',
+                                     'value': attachment_value
+                                 }])
+    attachments = list(dr.attachments())
+    assert len(attachments) == 1
+
+
+def test_create_data_row_with_media_type(dataset, image_url):
+    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
+        dataset.create_data_row(
+            row_data={'invalid_object': 'invalid_value'}, media_type="IMAGE")
+    assert "Found invalid contents for media type: 'IMAGE'" in str(exc.value)
+
+    dataset.create_data_row(row_data=image_url, media_type="IMAGE")
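Reviewer note: a standalone illustration of the rewrite now performed by the formatLegacyConversationalData helper, using plain dicts and no SDK calls. "type" and "version" now default when omitted, and the conversation stays inline as row_data instead of being uploaded as a separate JSON file.

# Standalone sketch of the legacy-payload rewrite; not SDK code.
legacy_item = {
    "conversationalData": [{
        "messageId": "message-0",
        "content": "I love iphone! i just bought new iphone!",
        "user": {"userId": "Bot 002", "name": "Bot"},
    }],
    "externalId": "conv-1",
}

messages = legacy_item.pop("conversationalData")
legacy_item["external_id"] = legacy_item.pop("externalId")
legacy_item["row_data"] = {
    "type": "application/vnd.labelbox.conversational",  # default type
    "version": 1,  # default version
    "messages": messages,
}
assert "conversationalData" not in legacy_item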
diff --git a/tests/integration/test_dataset.py b/tests/integration/test_dataset.py
index 8237f77bd..89a89b78c 100644
--- a/tests/integration/test_dataset.py
+++ b/tests/integration/test_dataset.py
@@ -2,7 +2,7 @@
 import pytest
 import requests
 from labelbox import Dataset
-from labelbox.exceptions import ResourceNotFoundError, MalformedQueryException
+from labelbox.exceptions import ResourceNotFoundError, MalformedQueryException, InvalidQueryError
 from labelbox.schema.dataset import MAX_DATAROW_PER_API_OPERATION
 
 
@@ -103,6 +103,33 @@ def test_upload_video_file(dataset, sample_video: str) -> None:
         assert response.headers['Content-Type'] == 'video/mp4'
 
 
+def test_create_pdf(dataset):
+    dataset.create_data_row(
+        row_data={
+            "pdfUrl":
+                "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf",
+            "textLayerUrl":
+                "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json"
+        })
+    dataset.create_data_row(row_data={
+        "pdfUrl":
+            "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf",
+        "textLayerUrl":
+            "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json"
+    },
+                            media_type="PDF")
+
+    with pytest.raises(InvalidQueryError):
+        # Wrong media type
+        dataset.create_data_row(row_data={
+            "pdfUrl":
+                "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf",
+            "textLayerUrl":
+                "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json"
+        },
+                                media_type="TEXT")
+
+
 def test_bulk_conversation(dataset, sample_bulk_conversation: list) -> None:
     """ Tests that bulk conversations can be uploaded.
 
@@ -133,7 +160,7 @@ def test_create_descriptor_file(dataset):
     upload_data_spy.assert_called()
     call_args, call_kwargs = upload_data_spy.call_args_list[0][
         0], upload_data_spy.call_args_list[0][1]
-    assert call_args == ('[{"data": "some text..."}]',)
+    assert call_args == ('[{"row_data": "some text..."}]',)
     assert call_kwargs == {
         'content_type': 'application/json',
         'filename': 'json_import.json'
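Reviewer note: a bulk-upload sketch combining the payload shapes exercised by these tests; `dataset` is assumed to be an existing labelbox.Dataset and the URLs are placeholders. Per the upload_if_necessary change, local paths are uploaded only when row_data is a string, and dict payloads are serialized with their snake_case row_data key intact, as test_create_descriptor_file now asserts.

# Bulk-upload sketch; `dataset` is a placeholder labelbox.Dataset.
# The tile payload mirrors the tile_content fixture above.
task = dataset.create_data_rows([
    {"row_data": "https://example.com/image.jpg"},
    {
        "row_data": {
            "tileLayerUrl": "https://example.com/{z}/{x}/{y}.png",
            "bounds": [[19.405, -99.210], [19.400, -99.205]],
            "minZoom": 12,
            "maxZoom": 20,
            "epsg": "EPSG4326"
        },
        "media_type": "TMS_SIMPLE"
    },
])
task.wait_till_done()
assert task.errors is None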