23 changes: 21 additions & 2 deletions labelbox/schema/dataset.py
@@ -24,6 +24,7 @@
from labelbox.pagination import PaginatedCollection
from labelbox.pydantic_compat import BaseModel
from labelbox.schema.data_row import DataRow
from labelbox.schema.embeddings import EmbeddingVector
from labelbox.schema.export_filters import DatasetExportFilters, build_filters
from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
from labelbox.schema.export_task import ExportTask
@@ -179,14 +180,20 @@ def convert_field_keys(items):
args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
args[DataRow.metadata_fields.name])

if "embeddings" in args:
args["embeddings"] = [
EmbeddingVector(**e).to_gql() for e in args["embeddings"]
]

query_str = """mutation CreateDataRowPyApi(
$row_data: String!,
$metadata_fields: [DataRowCustomMetadataUpsertInput!],
$attachments: [DataRowAttachmentInput!],
$media_type : MediaType,
$external_id : String,
$global_key : String,
-            $dataset: ID!
+            $dataset: ID!,
+            $embeddings: [DataRowEmbeddingVectorInput!]
){
createDataRow(
data:
@@ -198,6 +205,7 @@ def convert_field_keys(items):
globalKey: $global_key
attachments: $attachments
dataset: {connect: {id: $dataset}}
embeddings: $embeddings
}
)
{%s}
@@ -388,6 +396,13 @@ def validate_attachments(item):
)
return attachments

def validate_embeddings(item):
embeddings = item.get("embeddings")
if embeddings:
item["embeddings"] = [
EmbeddingVector(**e).to_gql() for e in embeddings
]

def validate_conversational_data(conversational_data: list) -> None:
"""
Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json
@@ -448,7 +463,9 @@ def validate_keys(item):
str) and item.get('row_data').startswith("s3:/"):
raise InvalidQueryError(
"row_data: s3 assets must start with 'https'.")
-            allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}
+            allowed_extra_fields = {
+                'attachments', 'media_type', 'dataset_id', 'embeddings'
+            }
invalid_keys = set(item) - {f.name for f in DataRow.fields()
} - allowed_extra_fields
if invalid_keys:
@@ -494,6 +511,8 @@ def convert_item(data_row_item):
validate_keys(item)
# Make sure attachments are valid
validate_attachments(item)
# Make sure embeddings are valid
validate_embeddings(item)
# Parse metadata fields if they exist
parse_metadata_fields(item)
# Upload any local file paths
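Taken together, the dataset.py changes let callers attach precomputed embedding vectors when creating data rows. A minimal usage sketch under the new API, not part of this diff; the API key, dataset ID, asset URL, and embedding_id below are placeholders, and the vector length must match the dimensionality of the referenced embedding:

    import labelbox as lb

    client = lb.Client(api_key="<API_KEY>")
    dataset = client.get_dataset("<DATASET_ID>")  # placeholder dataset ID

    # Each embeddings entry mirrors the EmbeddingVector model below:
    # embedding_id, vector, and an optional clusters list.
    dataset.create_data_row(
        row_data="https://example.com/image.jpg",  # placeholder asset URL
        embeddings=[{
            "embedding_id": "<EMBEDDING_ID>",  # ID of an existing custom embedding
            "vector": [0.1, 0.2, 0.3],         # placeholder values
            "clusters": [1],                   # optional cluster assignments
        }],
    )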
15 changes: 15 additions & 0 deletions labelbox/schema/embeddings.py
@@ -0,0 +1,15 @@
from typing import List, Optional, Any, Dict

from labelbox.pydantic_compat import BaseModel


class EmbeddingVector(BaseModel):
embedding_id: str
vector: List[float]
clusters: Optional[List[int]]

def to_gql(self) -> Dict[str, Any]:
@mrobers1982 (Contributor, Author), Mar 28, 2024: This is to convert snake_case (from Python) to camelCase (to GQL).
result = {"embeddingId": self.embedding_id, "vector": self.vector}
if self.clusters:
result["clusters"] = self.clusters
return result
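For reference, a quick sketch of the key renaming that to_gql() performs, using placeholder values; the output follows directly from the method above:

    from labelbox.schema.embeddings import EmbeddingVector

    ev = EmbeddingVector(embedding_id="emb_123",  # placeholder ID
                         vector=[0.1, 0.2],
                         clusters=[0, 3])
    print(ev.to_gql())
    # {'embeddingId': 'emb_123', 'vector': [0.1, 0.2], 'clusters': [0, 3]}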