diff --git a/labelbox/orm/db_object.py b/labelbox/orm/db_object.py
index cc10088c8..5b986afcb 100644
--- a/labelbox/orm/db_object.py
+++ b/labelbox/orm/db_object.py
@@ -70,6 +70,15 @@ def _set_field_values(self, field_values):
                         "field %s", value, field)
             elif isinstance(field.field_type, Field.EnumType):
                 value = field.field_type.enum_cls(value)
+            elif isinstance(field.field_type, Field.ListType):
+                if field.field_type.list_cls.__name__ == "DataRowMetadataField":
+                    mdo = self.client.get_data_row_metadata_ontology()
+                    try:
+                        value = mdo.parse_metadata_fields(value)
+                    except ValueError:
+                        logger.warning(
+                            "Failed to convert value '%s' to metadata for field %s",
+                            value, field)
             setattr(self, field.name, value)
 
     def __repr__(self):
diff --git a/labelbox/schema/batch.py b/labelbox/schema/batch.py
index 684ec60e9..b4c2373f0 100644
--- a/labelbox/schema/batch.py
+++ b/labelbox/schema/batch.py
@@ -106,7 +106,8 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
                 reader = ndjson.reader(StringIO(response.text))
                 # TODO: Update result to parse metadataFields when resolver returns
                 return (Entity.DataRow(self.client, {
-                    **result, 'metadataFields': []
+                    **result, 'metadataFields': [],
+                    'customMetadata': []
                 }) for result in reader)
             elif res["status"] == "FAILED":
                 raise LabelboxError("Data row export failed.")
diff --git a/labelbox/schema/data_row.py b/labelbox/schema/data_row.py
index 01b803ccc..76ee4bfcf 100644
--- a/labelbox/schema/data_row.py
+++ b/labelbox/schema/data_row.py
@@ -21,8 +21,9 @@ class DataRow(DbObject, Updateable, BulkDeletable):
             Otherwise, it's treated as an external URL.
         updated_at (datetime)
         created_at (datetime)
-        media_attributes (dict): generated media attributes for the datarow
-        metadata_fields (list): metadata associated with the datarow
+        media_attributes (dict): generated media attributes for the data row
+        metadata_fields (list): metadata associated with the data row
+        metadata (list): metadata associated with the data row as list of DataRowMetadataField
 
         dataset (Relationship): `ToOne` relationship to Dataset
         created_by (Relationship): `ToOne` relationship to User
@@ -36,10 +37,14 @@ class DataRow(DbObject, Updateable, BulkDeletable):
     created_at = Field.DateTime("created_at")
     media_attributes = Field.Json("media_attributes")
     metadata_fields = Field.List(
-        DataRowMetadataField,
+        dict,
         graphql_type="DataRowCustomMetadataUpsertInput!",
         name="metadata_fields",
         result_subquery="metadataFields { schemaId name value kind }")
+    metadata = Field.List(DataRowMetadataField,
+                          name="metadata",
+                          graphql_name="customMetadata",
+                          result_subquery="customMetadata { schemaId value }")
 
     # Relationships
     dataset = Relationship.ToOne("Dataset")
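Reviewer note: a minimal sketch of what the data_row.py change exposes, assuming an existing client and a placeholder data row id. `metadata_fields` stays the raw list of dicts from the resolver, while the new `metadata` field resolves the `customMetadata` subquery and is parsed into `DataRowMetadataField` objects by the db_object.py hook above.

    from labelbox import Client

    client = Client()  # assumes LABELBOX_API_KEY is set in the environment
    data_row = client.get_data_row("<data-row-id>")  # placeholder id

    # Raw resolver output: list of dicts with schemaId/name/value/kind.
    print(data_row.metadata_fields)

    # New: parsed DataRowMetadataField objects with schema_id/value attributes.
    for field in data_row.metadata:
        print(field.schema_id, field.value)
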
diff --git a/labelbox/schema/data_row_metadata.py b/labelbox/schema/data_row_metadata.py
index fa4d3aacc..47d6014e3 100644
--- a/labelbox/schema/data_row_metadata.py
+++ b/labelbox/schema/data_row_metadata.py
@@ -1,6 +1,5 @@
 # type: ignore
 from datetime import datetime
-import warnings
 from copy import deepcopy
 from enum import Enum
 from itertools import chain
@@ -224,34 +223,53 @@ def parse_metadata(
         for dr in unparsed:
             fields = []
-            for f in dr["fields"]:
-                if f["schemaId"] not in self.fields_by_id:
-                    # Update metadata ontology if field can't be found
-                    self.refresh_ontology()
-                    if f["schemaId"] not in self.fields_by_id:
-                        raise ValueError(
-                            f"Schema Id `{f['schemaId']}` not found in ontology"
-                        )
-
-                schema = self.fields_by_id[f["schemaId"]]
-                if schema.kind == DataRowMetadataKind.enum:
-                    continue
-                elif schema.kind == DataRowMetadataKind.option:
-                    field = DataRowMetadataField(schema_id=schema.parent,
-                                                 value=schema.uid)
-                elif schema.kind == DataRowMetadataKind.datetime:
-                    field = DataRowMetadataField(
-                        schema_id=schema.uid,
-                        value=datetime.fromisoformat(f["value"][:-1] +
-                                                     "+00:00"))
-                else:
-                    field = DataRowMetadataField(schema_id=schema.uid,
-                                                 value=f["value"])
-                fields.append(field)
+            if "fields" in dr:
+                fields = self.parse_metadata_fields(dr["fields"])
             parsed.append(
                 DataRowMetadata(data_row_id=dr["dataRowId"], fields=fields))
         return parsed
 
+    def parse_metadata_fields(
+            self, unparsed: List[Dict[str,
+                                      Dict]]) -> List[DataRowMetadataField]:
+        """ Parse metadata fields as list of `DataRowMetadataField`
+
+        >>> mdo.parse_metadata_fields([metadata_fields])
+
+        Args:
+            unparsed: An unparsed list of metadata represented as a dict containing 'schemaId' and 'value'
+
+        Returns:
+            metadata: List of `DataRowMetadataField`
+        """
+        parsed = []
+        if isinstance(unparsed, dict):
+            raise ValueError("Pass a list of dictionaries")
+
+        for f in unparsed:
+            if f["schemaId"] not in self.fields_by_id:
+                # Update metadata ontology if field can't be found
+                self.refresh_ontology()
+                if f["schemaId"] not in self.fields_by_id:
+                    raise ValueError(
+                        f"Schema Id `{f['schemaId']}` not found in ontology")
+
+            schema = self.fields_by_id[f["schemaId"]]
+            if schema.kind == DataRowMetadataKind.enum:
+                continue
+            elif schema.kind == DataRowMetadataKind.option:
+                field = DataRowMetadataField(schema_id=schema.parent,
+                                             value=schema.uid)
+            elif schema.kind == DataRowMetadataKind.datetime:
+                field = DataRowMetadataField(
+                    schema_id=schema.uid,
+                    value=datetime.fromisoformat(f["value"][:-1] + "+00:00"))
+            else:
+                field = DataRowMetadataField(schema_id=schema.uid,
+                                             value=f["value"])
+            parsed.append(field)
+        return parsed
+
     def bulk_upsert(
             self, metadata: List[DataRowMetadata]
     ) -> List[DataRowMetadataBatchResponse]:
diff --git a/labelbox/schema/dataset.py b/labelbox/schema/dataset.py
index 79b43ef95..9ba1d86de 100644
--- a/labelbox/schema/dataset.py
+++ b/labelbox/schema/dataset.py
@@ -442,7 +442,8 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
                 reader = ndjson.reader(StringIO(response.text))
                 # TODO: Update result to parse metadataFields when resolver returns
                 return (Entity.DataRow(self.client, {
-                    **result, 'metadataFields': []
+                    **result, 'metadataFields': [],
+                    'customMetadata': []
                 }) for result in reader)
             elif res["status"] == "FAILED":
                 raise LabelboxError("Data row export failed.")
diff --git a/tests/integration/test_data_row_metadata.py b/tests/integration/test_data_row_metadata.py
index 5f715e2c7..8f8b9207e 100644
--- a/tests/integration/test_data_row_metadata.py
+++ b/tests/integration/test_data_row_metadata.py
@@ -1,4 +1,3 @@
-import time
 from datetime import datetime
 
 import pytest
@@ -282,3 +281,34 @@ def test_parse_raw_metadata(mdo):
     for row in parsed:
         for field in row.fields:
             assert mdo._parse_upsert(field)
+
+
+def test_parse_raw_metadata_fields(mdo):
+    example = [
+        {
+            'schemaId': 'cko8s9r5v0001h2dk9elqdidh',
+            'value': 'my-new-message'
+        },
+        {
+            'schemaId': 'cko8sbczn0002h2dkdaxb5kal',
+            'value': {}
+        },
+        {
+            'schemaId': 'cko8sbscr0003h2dk04w86hof',
+            'value': {}
+        },
+        {
+            'schemaId': 'cko8sdzv70006h2dk8jg64zvb',
+            'value': '2021-07-20T21:41:14.606710Z'
+        },
+        {
+            'schemaId': FAKE_SCHEMA_ID,
+            'value': 0.5
+        },
+    ]
+
+    parsed = mdo.parse_metadata_fields(example)
+    assert len(parsed) == 4
+
+    for field in parsed:
+        assert mdo._parse_upsert(field)
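Reviewer note: a minimal sketch of calling the new `parse_metadata_fields` helper directly; the schemaId is copied from the test fixtures above and the value is illustrative.

    from labelbox import Client

    client = Client()
    mdo = client.get_data_row_metadata_ontology()

    raw = [{'schemaId': 'cko8s9r5v0001h2dk9elqdidh', 'value': 'my-new-message'}]

    # Raises ValueError when handed a bare dict instead of a list, or when a
    # schemaId is still unknown after refreshing the cached ontology.
    fields = mdo.parse_metadata_fields(raw)
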
diff --git a/tests/integration/test_data_rows.py b/tests/integration/test_data_rows.py
index 5a814a25b..ca88808dd 100644
--- a/tests/integration/test_data_rows.py
+++ b/tests/integration/test_data_rows.py
@@ -1,7 +1,5 @@
-import imghdr
 from tempfile import NamedTemporaryFile
 import uuid
-import time
 from datetime import datetime
 
 import pytest
@@ -22,6 +20,14 @@
 ].sort()
 
 
+@pytest.fixture
+def mdo(client):
+    mdo = client.get_data_row_metadata_ontology()
+    mdo._raw_ontology = mdo._get_ontology()
+    mdo._build_ontology()
+    yield mdo
+
+
 def make_metadata_fields():
     embeddings = [0.0] * 128
     msg = "A message"
@@ -57,12 +63,6 @@ def make_metadata_fields_dict():
     return fields
 
 
-def filter_precomputed_embeddings(metadata_fields):
-    return list(
-        filter(lambda md: md["name"] != "precomputedImageEmbedding",
-               metadata_fields))
-
-
 def test_get_data_row(datarow, client):
     assert client.get_data_row(datarow.uid)
 
@@ -235,7 +235,7 @@ def test_create_data_row_with_invalid_input(dataset, image_url):
         dataset.create_data_row(dr, row_data=image_url)
 
 
-def test_create_data_row_with_metadata(dataset, image_url):
+def test_create_data_row_with_metadata(mdo, dataset, image_url):
     client = dataset.client
     assert len(list(dataset.data_rows())) == 0
 
@@ -249,13 +249,17 @@
     assert requests.get(image_url).content == \
         requests.get(data_row.row_data).content
     assert data_row.media_attributes is not None
-    filtered_md_fields = filter_precomputed_embeddings(data_row.metadata_fields)
-    assert len(filtered_md_fields) == 4
-    assert [m["schemaId"] for m in filtered_md_fields
+    metadata_fields = data_row.metadata_fields
+    metadata = data_row.metadata
+    assert len(metadata_fields) == 4
+    assert len(metadata) == 4
+    assert [m["schemaId"] for m in metadata_fields
            ].sort() == EXPECTED_METADATA_SCHEMA_IDS
+    for m in metadata:
+        assert mdo._parse_upsert(m)
 
 
-def test_create_data_row_with_metadata_dict(dataset, image_url):
+def test_create_data_row_with_metadata_dict(mdo, dataset, image_url):
     client = dataset.client
     assert len(list(dataset.data_rows())) == 0
 
@@ -269,10 +273,14 @@
     assert requests.get(image_url).content == \
         requests.get(data_row.row_data).content
     assert data_row.media_attributes is not None
-    filtered_md_fields = filter_precomputed_embeddings(data_row.metadata_fields)
-    assert len(filtered_md_fields) == 4
-    assert [m["schemaId"] for m in filtered_md_fields
+    metadata_fields = data_row.metadata_fields
+    metadata = data_row.metadata
+    assert len(metadata_fields) == 4
+    assert len(metadata) == 4
+    assert [m["schemaId"] for m in metadata_fields
            ].sort() == EXPECTED_METADATA_SCHEMA_IDS
+    for m in metadata:
+        assert mdo._parse_upsert(m)
 
 
 def test_create_data_row_with_invalid_metadata(dataset, image_url):
@@ -284,7 +292,7 @@
         dataset.create_data_row(row_data=image_url, metadata_fields=fields)
 
 
-def test_create_data_rows_with_metadata(dataset, image_url):
+def test_create_data_rows_with_metadata(mdo, dataset, image_url):
     client = dataset.client
     assert len(list(dataset.data_rows())) == 0
 
@@ -322,11 +330,14 @@
         requests.get(row.row_data).content
         assert row.media_attributes is not None
 
-        # Remove 'precomputedImageEmbedding' metadata if automatically added
-        filtered_md_fields = filter_precomputed_embeddings(row.metadata_fields)
-        assert len(filtered_md_fields) == 4
-        assert [m["schemaId"] for m in filtered_md_fields
+        metadata_fields = row.metadata_fields
+        metadata = row.metadata
+        assert len(metadata_fields) == 4
+        assert len(metadata) == 4
+        assert [m["schemaId"] for m in metadata_fields
                ].sort() == EXPECTED_METADATA_SCHEMA_IDS
+        for m in metadata:
+            assert mdo._parse_upsert(m)
 
 
 def test_create_data_rows_with_invalid_metadata(dataset, image_url):
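
Reviewer note: with the matching batch.py and dataset.py changes, exported rows construct cleanly but carry empty metadata until the export resolver returns these fields; a sketch assuming an existing `dataset`.

    # export_data_rows yields DataRow objects; both metadata fields are
    # stubbed to empty lists pending resolver support (see the TODOs).
    for data_row in dataset.export_data_rows():
        assert data_row.metadata_fields == []
        assert data_row.metadata == []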