diff --git a/labelbox/schema/batch.py b/labelbox/schema/batch.py
index 0f7da0d32..684ec60e9 100644
--- a/labelbox/schema/batch.py
+++ b/labelbox/schema/batch.py
@@ -104,9 +104,9 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
                 response = requests.get(download_url)
                 response.raise_for_status()
                 reader = ndjson.reader(StringIO(response.text))
-                # TODO: Update result to parse customMetadata when resolver returns
+                # TODO: Update result to parse metadataFields when resolver returns
                 return (Entity.DataRow(self.client, {
-                    **result, 'customMetadata': []
+                    **result, 'metadataFields': []
                 }) for result in reader)
             elif res["status"] == "FAILED":
                 raise LabelboxError("Data row export failed.")
diff --git a/labelbox/schema/data_row.py b/labelbox/schema/data_row.py
index 7a032566d..01b803ccc 100644
--- a/labelbox/schema/data_row.py
+++ b/labelbox/schema/data_row.py
@@ -22,7 +22,7 @@ class DataRow(DbObject, Updateable, BulkDeletable):
         updated_at (datetime)
         created_at (datetime)
         media_attributes (dict): generated media attributes for the datarow
-        custom_metadata (list): metadata associated with the datarow
+        metadata_fields (list): metadata associated with the datarow

         dataset (Relationship): `ToOne` relationship to Dataset
         created_by (Relationship): `ToOne` relationship to User
@@ -35,11 +35,11 @@ class DataRow(DbObject, Updateable, BulkDeletable):
     updated_at = Field.DateTime("updated_at")
     created_at = Field.DateTime("created_at")
     media_attributes = Field.Json("media_attributes")
-    custom_metadata = Field.List(
+    metadata_fields = Field.List(
         DataRowMetadataField,
         graphql_type="DataRowCustomMetadataUpsertInput!",
-        name="custom_metadata",
-        result_subquery="customMetadata { value schemaId }")
+        name="metadata_fields",
+        result_subquery="metadataFields { schemaId name value kind }")

     # Relationships
     dataset = Relationship.ToOne("Dataset")
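For context on the `data_row.py` change above: the new `result_subquery` returns richer per-field records, so fetched data rows expose `schemaId`, `name`, `value`, and `kind` for each metadata field. A minimal sketch of reading the renamed attribute, assuming a `Client` configured via the `LABELBOX_API_KEY` environment variable and a placeholder data row ID:

```python
from labelbox import Client

client = Client()  # assumes LABELBOX_API_KEY is set in the environment
data_row = client.get_data_row("<DATA_ROW_ID>")  # placeholder ID

# `custom_metadata` is now `metadata_fields`; each entry mirrors the
# new result subquery: metadataFields { schemaId name value kind }
for field in data_row.metadata_fields:
    print(field["schemaId"], field["name"], field["kind"], field["value"])
```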
""" + invalid_argument_error = "Argument to create_data_row() must be either a dictionary, or kwargs containing `row_data` at minimum" + + def convert_field_keys(items): + if not isinstance(items, dict): + raise InvalidQueryError(invalid_argument_error) + return { + key.name if isinstance(key, Field) else key: value + for key, value in items.items() + } + + if items is not None and len(kwargs) > 0: + raise InvalidQueryError(invalid_argument_error) + DataRow = Entity.DataRow - if DataRow.row_data.name not in kwargs: + args = convert_field_keys(items) if items is not None else kwargs + + if DataRow.row_data.name not in args: raise InvalidQueryError( "DataRow.row_data missing when creating DataRow.") # If row data is a local file path, upload it to server. - row_data = kwargs[DataRow.row_data.name] + row_data = args[DataRow.row_data.name] if os.path.exists(row_data): - kwargs[DataRow.row_data.name] = self.client.upload_file(row_data) - kwargs[DataRow.dataset.name] = self + args[DataRow.row_data.name] = self.client.upload_file(row_data) + args[DataRow.dataset.name] = self # Parse metadata fields, if they are provided - if DataRow.custom_metadata.name in kwargs: + if DataRow.metadata_fields.name in args: mdo = self.client.get_data_row_metadata_ontology() - kwargs[DataRow.custom_metadata.name] = mdo.parse_upsert_metadata( - kwargs[DataRow.custom_metadata.name]) - - return self.client._create(DataRow, kwargs) + args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata( + args[DataRow.metadata_fields.name]) + return self.client._create(DataRow, args) def create_data_rows_sync(self, items) -> None: """ Synchronously bulk upload data rows. @@ -264,10 +281,10 @@ def validate_attachments(item): return attachments def parse_metadata_fields(item): - metadata_fields = item.get('custom_metadata') + metadata_fields = item.get('metadata_fields') if metadata_fields: mdo = self.client.get_data_row_metadata_ontology() - item['custom_metadata'] = mdo.parse_upsert_metadata( + item['metadata_fields'] = mdo.parse_upsert_metadata( metadata_fields) def format_row(item): @@ -413,9 +430,9 @@ def export_data_rows(self, timeout_seconds=120) -> Generator: response = requests.get(download_url) response.raise_for_status() reader = ndjson.reader(StringIO(response.text)) - # TODO: Update result to parse customMetadata when resolver returns + # TODO: Update result to parse metadataFields when resolver returns return (Entity.DataRow(self.client, { - **result, 'customMetadata': [] + **result, 'metadataFields': [] }) for result in reader) elif res["status"] == "FAILED": raise LabelboxError("Data row export failed.") diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py index dbc138fda..11db427dc 100644 --- a/tests/integration/test_data_rows.py +++ b/tests/integration/test_data_rows.py @@ -1,3 +1,4 @@ +import imghdr from tempfile import NamedTemporaryFile import uuid import time @@ -177,12 +178,63 @@ def test_data_row_single_creation(dataset, rand_gen, image_url): assert requests.get(data_row_2.row_data).content == data +def test_create_data_row_with_dict(dataset, image_url): + client = dataset.client + assert len(list(dataset.data_rows())) == 0 + dr = {"row_data": image_url} + data_row = dataset.create_data_row(dr) + assert len(list(dataset.data_rows())) == 1 + assert data_row.dataset() == dataset + assert data_row.created_by() == client.get_user() + assert data_row.organization() == client.get_organization() + assert requests.get(image_url).content == \ + 
diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py
index dbc138fda..11db427dc 100644
--- a/tests/integration/test_data_rows.py
+++ b/tests/integration/test_data_rows.py
@@ -1,3 +1,4 @@
+import imghdr
 from tempfile import NamedTemporaryFile
 import uuid
 import time
@@ -177,12 +178,63 @@ def test_data_row_single_creation(dataset, rand_gen, image_url):
     assert requests.get(data_row_2.row_data).content == data


+def test_create_data_row_with_dict(dataset, image_url):
+    client = dataset.client
+    assert len(list(dataset.data_rows())) == 0
+    dr = {"row_data": image_url}
+    data_row = dataset.create_data_row(dr)
+    assert len(list(dataset.data_rows())) == 1
+    assert data_row.dataset() == dataset
+    assert data_row.created_by() == client.get_user()
+    assert data_row.organization() == client.get_organization()
+    assert requests.get(image_url).content == \
+        requests.get(data_row.row_data).content
+    assert data_row.media_attributes is not None
+
+
+def test_create_data_row_with_dict_containing_field(dataset, image_url):
+    client = dataset.client
+    assert len(list(dataset.data_rows())) == 0
+    dr = {DataRow.row_data: image_url}
+    data_row = dataset.create_data_row(dr)
+    assert len(list(dataset.data_rows())) == 1
+    assert data_row.dataset() == dataset
+    assert data_row.created_by() == client.get_user()
+    assert data_row.organization() == client.get_organization()
+    assert requests.get(image_url).content == \
+        requests.get(data_row.row_data).content
+    assert data_row.media_attributes is not None
+
+
+def test_create_data_row_with_dict_unpacked(dataset, image_url):
+    client = dataset.client
+    assert len(list(dataset.data_rows())) == 0
+    dr = {"row_data": image_url}
+    data_row = dataset.create_data_row(**dr)
+    assert len(list(dataset.data_rows())) == 1
+    assert data_row.dataset() == dataset
+    assert data_row.created_by() == client.get_user()
+    assert data_row.organization() == client.get_organization()
+    assert requests.get(image_url).content == \
+        requests.get(data_row.row_data).content
+    assert data_row.media_attributes is not None
+
+
+def test_create_data_row_with_invalid_input(dataset, image_url):
+    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
+        dataset.create_data_row("asdf")
+
+    dr = {"row_data": image_url}
+    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
+        dataset.create_data_row(dr, row_data=image_url)
+
+
 def test_create_data_row_with_metadata(dataset, image_url):
     client = dataset.client
     assert len(list(dataset.data_rows())) == 0

     data_row = dataset.create_data_row(row_data=image_url,
-                                       custom_metadata=make_metadata_fields())
+                                       metadata_fields=make_metadata_fields())

     assert len(list(dataset.data_rows())) == 1
     assert data_row.dataset() == dataset
@@ -191,8 +243,8 @@ def test_create_data_row_with_metadata(dataset, image_url):
     assert requests.get(image_url).content == \
         requests.get(data_row.row_data).content
     assert data_row.media_attributes is not None
-    assert len(data_row.custom_metadata) == 5
-    assert [m["schemaId"] for m in data_row.custom_metadata
+    assert len(data_row.metadata_fields) == 4
+    assert [m["schemaId"] for m in data_row.metadata_fields
            ].sort() == EXPECTED_METADATA_SCHEMA_IDS


@@ -201,7 +253,7 @@ def test_create_data_row_with_metadata_dict(dataset, image_url):
     assert len(list(dataset.data_rows())) == 0

     data_row = dataset.create_data_row(
-        row_data=image_url, custom_metadata=make_metadata_fields_dict())
+        row_data=image_url, metadata_fields=make_metadata_fields_dict())

     assert len(list(dataset.data_rows())) == 1
     assert data_row.dataset() == dataset
@@ -210,8 +262,8 @@ def test_create_data_row_with_metadata_dict(dataset, image_url):
     assert requests.get(image_url).content == \
         requests.get(data_row.row_data).content
     assert data_row.media_attributes is not None
-    assert len(data_row.custom_metadata) == 5
-    assert [m["schemaId"] for m in data_row.custom_metadata
+    assert len(data_row.metadata_fields) == 4
+    assert [m["schemaId"] for m in data_row.metadata_fields
            ].sort() == EXPECTED_METADATA_SCHEMA_IDS


@@ -221,7 +273,7 @@ def test_create_data_row_with_invalid_metadata(dataset, image_url):
         DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID, value=[0.0] * 128))

     with pytest.raises(labelbox.exceptions.MalformedQueryException) as excinfo:
-        dataset.create_data_row(row_data=image_url, custom_metadata=fields)
+        dataset.create_data_row(row_data=image_url, metadata_fields=fields)


 def test_create_data_rows_with_metadata(dataset, image_url):
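The metadata tests above rely on fixture helpers; for reference, `metadata_fields` accepts either `DataRowMetadataField` objects or plain dicts, and `parse_upsert_metadata` normalizes both. A hedged sketch of the two shapes with a placeholder schema ID (the real tests use `make_metadata_fields()` / `make_metadata_fields_dict()`):

```python
from labelbox.schema.data_row_metadata import DataRowMetadataField

TEXT_SCHEMA_ID = "<METADATA_SCHEMA_ID>"  # placeholder schema ID

# Object form, as make_metadata_fields() builds:
fields = [DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="free text")]

# Equivalent plain-dict form, as make_metadata_fields_dict() builds:
fields_dict = [{"schema_id": TEXT_SCHEMA_ID, "value": "free text"}]

dataset.create_data_row(row_data=image_url, metadata_fields=fields)
```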
@@ -232,22 +284,22 @@ def test_create_data_rows_with_metadata(dataset, image_url):
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row1",
-            DataRow.custom_metadata: make_metadata_fields()
+            DataRow.metadata_fields: make_metadata_fields()
         },
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row2",
-            "custom_metadata": make_metadata_fields()
+            "metadata_fields": make_metadata_fields()
         },
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row3",
-            DataRow.custom_metadata: make_metadata_fields_dict()
+            DataRow.metadata_fields: make_metadata_fields_dict()
         },
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row4",
-            "custom_metadata": make_metadata_fields_dict()
+            "metadata_fields": make_metadata_fields_dict()
         },
     ])
     task.wait_till_done()
@@ -261,8 +313,8 @@ def test_create_data_rows_with_metadata(dataset, image_url):
         assert requests.get(image_url).content == \
             requests.get(row.row_data).content
         assert row.media_attributes is not None
-        assert len(row.custom_metadata) == 5
-        assert [m["schemaId"] for m in row.custom_metadata
+        assert len(row.metadata_fields) == 4
+        assert [m["schemaId"] for m in row.metadata_fields
                ].sort() == EXPECTED_METADATA_SCHEMA_IDS
@@ -273,7 +325,7 @@ def test_create_data_rows_with_invalid_metadata(dataset, image_url):
     task = dataset.create_data_rows([{
         DataRow.row_data: image_url,
-        DataRow.custom_metadata: fields
+        DataRow.metadata_fields: fields
     }])
     task.wait_till_done()
     assert task.status == "FAILED"
@@ -288,7 +340,7 @@ def test_create_data_rows_with_metadata_missing_value(dataset, image_url):
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row1",
-            DataRow.custom_metadata: fields
+            DataRow.metadata_fields: fields
         },
     ])
@@ -302,7 +354,7 @@ def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url):
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row1",
-            DataRow.custom_metadata: fields
+            DataRow.metadata_fields: fields
         },
     ])
@@ -316,7 +368,7 @@ def test_create_data_rows_with_metadata_wrong_type(dataset, image_url):
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row1",
-            DataRow.custom_metadata: fields
+            DataRow.metadata_fields: fields
         },
     ])
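The bulk path mirrors the single-row rename: `create_data_rows` items take `metadata_fields` under either a string key or `DataRow.metadata_fields`. A minimal sketch under the same placeholder assumptions as the earlier snippets:

```python
task = dataset.create_data_rows([
    {
        DataRow.row_data: image_url,
        DataRow.external_id: "row1",
        DataRow.metadata_fields: fields,  # same shapes as the single-row call
    },
    {
        "row_data": image_url,
        "external_id": "row2",
        "metadata_fields": fields_dict,
    },
])
task.wait_till_done()
assert task.status == "COMPLETE"
```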