diff --git a/labelbox/schema/dataset.py b/labelbox/schema/dataset.py
index e43577112..3c98a966d 100644
--- a/labelbox/schema/dataset.py
+++ b/labelbox/schema/dataset.py
@@ -24,6 +24,7 @@
 from labelbox.pagination import PaginatedCollection
 from labelbox.pydantic_compat import BaseModel
 from labelbox.schema.data_row import DataRow
+from labelbox.schema.embeddings import EmbeddingVector
 from labelbox.schema.export_filters import DatasetExportFilters, build_filters
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
 from labelbox.schema.export_task import ExportTask
@@ -179,6 +180,12 @@ def convert_field_keys(items):
             args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
                 args[DataRow.metadata_fields.name])
 
+        # Nothing to convert when embeddings is missing, None or empty.
+        if args.get("embeddings"):
+            args["embeddings"] = [
+                EmbeddingVector(**e).to_gql() for e in args["embeddings"]
+            ]
+
         query_str = """mutation CreateDataRowPyApi(
              $row_data: String!,
              $metadata_fields: [DataRowCustomMetadataUpsertInput!],
@@ -186,7 +193,8 @@ def convert_field_keys(items):
              $media_type : MediaType,
              $external_id : String,
              $global_key : String,
-             $dataset: ID!
+             $dataset: ID!,
+             $embeddings: [DataRowEmbeddingVectorInput!]
             ){
                 createDataRow(
                     data:
@@ -198,6 +206,7 @@ def convert_field_keys(items):
                         globalKey: $global_key
                         attachments: $attachments
                         dataset: {connect: {id: $dataset}}
+                        embeddings: $embeddings
                     }
                    )
                 {%s}
@@ -388,6 +397,14 @@ def validate_attachments(item):
                     )
             return attachments
 
+        def validate_embeddings(item):
+            """Replace item["embeddings"] with validated to_gql() dicts, in place."""
+            embeddings = item.get("embeddings")
+            if embeddings:
+                item["embeddings"] = [
+                    EmbeddingVector(**e).to_gql() for e in embeddings
+                ]
+
         def validate_conversational_data(conversational_data: list) -> None:
             """
             Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json
@@ -448,7 +465,9 @@ def validate_keys(item):
                           str) and item.get('row_data').startswith("s3:/"):
                 raise InvalidQueryError(
                     "row_data: s3 assets must start with 'https'.")
-            allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}
+            allowed_extra_fields = {
+                'attachments', 'media_type', 'dataset_id', 'embeddings'
+            }
             invalid_keys = set(item) - {f.name for f in DataRow.fields()
                                        } - allowed_extra_fields
             if invalid_keys:
@@ -494,6 +513,8 @@ def convert_item(data_row_item):
             validate_keys(item)
             # Make sure attachments are valid
             validate_attachments(item)
+            # Make sure embeddings are valid
+            validate_embeddings(item)
             # Parse metadata fields if they exist
             parse_metadata_fields(item)
             # Upload any local file paths
diff --git a/labelbox/schema/embeddings.py b/labelbox/schema/embeddings.py
new file mode 100644
index 000000000..2ba687322
--- /dev/null
+++ b/labelbox/schema/embeddings.py
@@ -0,0 +1,20 @@
+from typing import List, Optional, Any, Dict
+
+from labelbox.pydantic_compat import BaseModel
+
+
+class EmbeddingVector(BaseModel):
+    """One embedding entry: an embedding id, its vector and optional clusters.
+
+    ``clusters`` defaults to ``None`` so callers may leave it out.
+    """
+    embedding_id: str
+    vector: List[float]
+    clusters: Optional[List[int]] = None
+
+    def to_gql(self) -> Dict[str, Any]:
+        """Return ``embeddingId``/``vector``; add ``clusters`` only if non-empty."""
+        result = {"embeddingId": self.embedding_id, "vector": self.vector}
+        if self.clusters:
+            result["clusters"] = self.clusters
+        return result