2 changes: 1 addition & 1 deletion labelbox/__init__.py
@@ -5,7 +5,7 @@
from labelbox.schema.project import Project
from labelbox.schema.model import Model
from labelbox.schema.bulk_import_request import BulkImportRequest
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport, MEAToMALPredictionImport
from labelbox.schema.dataset import Dataset
from labelbox.schema.data_row import DataRow
from labelbox.schema.label import Label
1 change: 1 addition & 0 deletions labelbox/data/annotation_types/data/base_data.py
@@ -11,5 +11,6 @@ class BaseData(BaseModel, ABC):
"""
external_id: Optional[str] = None
uid: Optional[str] = None
global_key: Optional[str] = None
media_attributes: Optional[Dict[str, Any]] = None
metadata: Optional[List[Dict[str, Any]]] = None
5 changes: 3 additions & 2 deletions labelbox/data/annotation_types/data/raster.py
@@ -163,9 +163,10 @@ def validate_args(cls, values):
url = values.get("url")
arr = values.get("arr")
uid = values.get('uid')
if uid == file_path == im_bytes == url == None and arr is None:
global_key = values.get('global_key')
if uid == file_path == im_bytes == url == global_key == None and arr is None:
raise ValueError(
"One of `file_path`, `im_bytes`, `url`, `uid` or `arr` required."
"One of `file_path`, `im_bytes`, `url`, `uid`, `global_key` or `arr` required."
)
if arr is not None:
if arr.dtype != np.uint8:
6 changes: 4 additions & 2 deletions labelbox/data/annotation_types/data/text.py
@@ -93,9 +93,11 @@ def validate_date(cls, values):
text = values.get("text")
url = values.get("url")
uid = values.get('uid')
if uid == file_path == text == url == None:
global_key = values.get('global_key')
if uid == file_path == text == url == global_key == None:
raise ValueError(
"One of `file_path`, `text`, `uid`, or `url` required.")
"One of `file_path`, `text`, `uid`, `global_key` or `url` required."
)
return values

def __repr__(self) -> str:
6 changes: 4 additions & 2 deletions labelbox/data/annotation_types/data/video.py
@@ -153,10 +153,12 @@ def validate_data(cls, values):
url = values.get("url")
frames = values.get("frames")
uid = values.get("uid")
global_key = values.get("global_key")

if uid == file_path == frames == url == None:
if uid == file_path == frames == url == global_key == None:
raise ValueError(
"One of `file_path`, `frames`, `uid`, or `url` required.")
"One of `file_path`, `frames`, `uid`, `global_key` or `url` required."
)
return values

def __repr__(self) -> str:
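With the three validator updates above (raster, text, video), a data object can now be referenced by a global key alone. A minimal illustrative sketch, not part of this diff; the key values are placeholders:

    from labelbox.data.annotation_types.data import ImageData, TextData, VideoData

    # Each constructor now passes its root validator with only a global_key set.
    image = ImageData(global_key="my-image-global-key")
    text = TextData(global_key="my-text-global-key")
    video = VideoData(global_key="my-video-global-key")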
31 changes: 17 additions & 14 deletions labelbox/data/serialization/ndjson/base.py
@@ -2,34 +2,37 @@
from uuid import uuid4
from pydantic import BaseModel, root_validator, validator, Field

from labelbox.utils import camel_case
from labelbox.utils import _CamelCaseMixin, camel_case, is_exactly_one_set
from ...annotation_types.types import Cuid


class DataRow(BaseModel):
class DataRow(_CamelCaseMixin):
id: str = None
global_key: str = None

@validator('id', pre=True, always=True)
def validate_id(cls, v):
if v is None:
raise ValueError(
"Data row ids are not set. Use `LabelGenerator.add_to_dataset`,or `Label.create_data_row`. "
"You can also manually assign the id for each `BaseData` object"
)
return v
@root_validator()
def must_set_one(cls, values):
if not is_exactly_one_set(values.get('id'), values.get('global_key')):
raise ValueError("Must set either id or global_key")
return values


class NDJsonBase(BaseModel):
class NDJsonBase(_CamelCaseMixin):
uuid: str = None
data_row: DataRow

@validator('uuid', pre=True, always=True)
def set_id(cls, v):
return v or str(uuid4())

class Config:
allow_population_by_field_name = True
alias_generator = camel_case
def dict(self, *args, **kwargs):
""" Pop missing id or missing globalKey from dataRow """
res = super().dict(*args, **kwargs)
if not self.data_row.id:
res['dataRow'].pop('id')
if not self.data_row.global_key:
res['dataRow'].pop('globalKey')
return res


class NDAnnotation(NDJsonBase):
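A rough sketch of the new `DataRow` validation behavior in base.py (placeholder values; pydantic surfaces the ValueError as a ValidationError):

    from pydantic import ValidationError
    from labelbox.data.serialization.ndjson.base import DataRow

    DataRow(id="ckz-data-row-id")          # accepted: exactly one reference set
    DataRow(global_key="my-global-key")    # accepted

    for bad in (dict(), dict(id="ckz-data-row-id", global_key="my-global-key")):
        try:
            DataRow(**bad)
        except ValidationError:
            pass  # "Must set either id or global_key": neither or both were provided

`NDJsonBase.dict()` then drops whichever of `id`/`globalKey` was left unset from the serialized `dataRow`.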
8 changes: 4 additions & 4 deletions labelbox/data/serialization/ndjson/classification.py
@@ -8,7 +8,7 @@
from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio
from ...annotation_types.types import Cuid
from ...annotation_types.data import TextData, VideoData, ImageData
from .base import NDAnnotation
from .base import DataRow, NDAnnotation


class NDFeature(ConfidenceMixin):
@@ -125,7 +125,7 @@ def from_common(cls, text: Text, name: str, feature_schema_id: Cuid,
ImageData]) -> "NDText":
return cls(
answer=text.answer,
data_row={'id': data.uid},
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
@@ -145,7 +145,7 @@ def from_common(
confidence=answer.confidence)
for answer in checklist.answer
],
data_row={'id': data.uid},
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
@@ -161,7 +161,7 @@ def from_common(cls, radio: Radio, name: str, feature_schema_id: Cuid,
return cls(answer=NDFeature(name=radio.answer.name,
schema_id=radio.answer.feature_schema_id,
confidence=radio.answer.confidence),
data_row={'id': data.uid},
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
37 changes: 25 additions & 12 deletions labelbox/data/serialization/ndjson/label.py
@@ -1,6 +1,6 @@
from itertools import groupby
from operator import itemgetter
from typing import Dict, Generator, List, Tuple, Union
from typing import Dict, Generator, List, Optional, Tuple, Union
from collections import defaultdict
import warnings

@@ -17,6 +17,7 @@
from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric
from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass
from .objects import NDObject, NDObjectType, NDSegments
from .base import DataRow


class NDLabel(BaseModel):
@@ -27,7 +28,9 @@ class NDLabel(BaseModel):
def to_common(self) -> LabelGenerator:
grouped_annotations = defaultdict(list)
for annotation in self.annotations:
grouped_annotations[annotation.data_row.id].append(annotation)
grouped_annotations[annotation.data_row.id or
annotation.data_row.global_key].append(
annotation)
return LabelGenerator(
data=self._generate_annotations(grouped_annotations))

@@ -45,9 +48,11 @@ def _generate_annotations(
NDConfusionMatrixMetric,
NDScalarMetric, NDSegments]]]
) -> Generator[Label, None, None]:
for data_row_id, annotations in grouped_annotations.items():
for _, annotations in grouped_annotations.items():
annots = []
data_row = annotations[0].data_row
for annotation in annotations:

if isinstance(annotation, NDSegments):
annots.extend(
NDSegments.to_common(annotation, annotation.name,
@@ -62,22 +67,30 @@ def _generate_annotations(
else:
raise TypeError(
f"Unsupported annotation. {type(annotation)}")
data = self._infer_media_type(annots)(uid=data_row_id)
yield Label(annotations=annots, data=data)
yield Label(annotations=annots,
data=self._infer_media_type(data_row, annots))

def _infer_media_type(
self, annotations: List[Union[TextEntity, VideoClassificationAnnotation,
VideoObjectAnnotation, ObjectAnnotation,
ClassificationAnnotation, ScalarMetric,
ConfusionMatrixMetric]]
self, data_row: DataRow,
annotations: List[Union[TextEntity, VideoClassificationAnnotation,
VideoObjectAnnotation, ObjectAnnotation,
ClassificationAnnotation, ScalarMetric,
ConfusionMatrixMetric]]
) -> Union[TextData, VideoData, ImageData]:
if len(annotations) == 0:
raise ValueError("Missing annotations while inferring media type")

types = {type(annotation) for annotation in annotations}
data = ImageData
if TextEntity in types:
return TextData
data = TextData
elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types:
return VideoData
data = VideoData

if data_row.id:
return data(uid=data_row.id)
else:
return ImageData
return data(global_key=data_row.global_key)

@staticmethod
def _get_consecutive_frames(
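A condensed, illustrative sketch of the construction `_infer_media_type` now performs (`ImageData` is the default when no text or video annotations are present; the key is a placeholder):

    from labelbox.data.annotation_types.data import ImageData
    from labelbox.data.serialization.ndjson.base import DataRow

    data_row = DataRow(global_key="my-global-key")
    # Instantiate the inferred class from whichever reference the data row carries.
    data = (ImageData(uid=data_row.id) if data_row.id
            else ImageData(global_key=data_row.global_key))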
6 changes: 3 additions & 3 deletions labelbox/data/serialization/ndjson/metric.py
@@ -1,7 +1,7 @@
from typing import Optional, Union, Type

from labelbox.data.annotation_types.data import ImageData, TextData
from labelbox.data.serialization.ndjson.base import NDJsonBase
from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase
from labelbox.data.annotation_types.metrics.scalar import (
ScalarMetric, ScalarMetricAggregation, ScalarMetricValue,
ScalarMetricConfidenceValue)
@@ -50,7 +50,7 @@ def from_common(
feature_name=metric.feature_name,
subclass_name=metric.subclass_name,
aggregation=metric.aggregation,
data_row={'id': data.uid})
data_row=DataRow(id=data.uid, global_key=data.global_key))


class NDScalarMetric(BaseNDMetric):
@@ -75,7 +75,7 @@ def from_common(cls, metric: ScalarMetric,
feature_name=metric.feature_name,
subclass_name=metric.subclass_name,
aggregation=metric.aggregation.value,
data_row={'id': data.uid})
data_row=DataRow(id=data.uid, global_key=data.global_key))

def dict(self, *args, **kwargs):
res = super().dict(*args, **kwargs)
14 changes: 7 additions & 7 deletions labelbox/data/serialization/ndjson/objects.py
@@ -60,7 +60,7 @@ def from_common(cls,
'x': point.x,
'y': point.y
},
dataRow=DataRow(id=data.uid),
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
@@ -105,7 +105,7 @@ def from_common(cls,
'x': pt.x,
'y': pt.y
} for pt in line.points],
dataRow=DataRow(id=data.uid),
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
@@ -154,7 +154,7 @@ def from_common(cls,
'x': pt.x,
'y': pt.y
} for pt in polygon.points],
dataRow=DataRow(id=data.uid),
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
@@ -183,7 +183,7 @@ def from_common(cls,
left=rectangle.start.x,
height=rectangle.end.y - rectangle.start.y,
width=rectangle.end.x - rectangle.start.x),
dataRow=DataRow(id=data.uid),
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
@@ -280,7 +280,7 @@ def from_common(cls, segments: List[VideoObjectAnnotation], data: VideoData,
segments = [NDSegment.from_common(segment) for segment in segments]

return cls(segments=segments,
dataRow=DataRow(id=data.uid),
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'))
@@ -332,7 +332,7 @@ def from_common(cls,
png=base64.b64encode(im_bytes.getvalue()).decode('utf-8'))

return cls(mask=lbv1_mask,
dataRow=DataRow(id=data.uid),
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
@@ -364,7 +364,7 @@ def from_common(cls,
start=text_entity.start,
end=text_entity.end,
),
dataRow=DataRow(id=data.uid),
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
34 changes: 34 additions & 0 deletions labelbox/schema/annotation_import.py
@@ -14,6 +14,7 @@
from labelbox.orm import query
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field, Relationship
from labelbox.utils import is_exactly_one_set
from labelbox.schema.confidence_presence_checker import LabelsConfidencePresenceChecker
from labelbox.schema.enums import AnnotationImportState
from labelbox.schema.serialization import serialize_labels
@@ -155,6 +156,8 @@ def _get_ndjson_from_objects(cls, objects: Union[List[Dict[str, Any]],
)

objects = serialize_labels(objects)
cls._validate_data_rows(objects)

data_str = ndjson.dumps(objects)
if not data_str:
raise ValueError(f"{object_name} cannot be empty")
@@ -171,6 +174,37 @@ def refresh(self) -> None:
as_json=True)
self._set_field_values(res)

@classmethod
def _validate_data_rows(cls, objects: List[Dict[str, Any]]):
"""
Validates annotations by checking 'dataRow' is provided
and only one of 'id' or 'globalKey' is provided.

Shows up to `max_num_errors` errors if validation fails, to prevent
a large number of error messages from being printed out.
"""
errors = []
max_num_errors = 100
for object in objects:
if 'dataRow' not in object:
errors.append(f"'dataRow' is missing in {object}")
elif not is_exactly_one_set(object['dataRow'].get('id'),
object['dataRow'].get('globalKey')):
errors.append(
f"Must provide only one of 'id' or 'globalKey' for 'dataRow' in {object}"
)

if errors:
errors_length = len(errors)
formatted_errors = '\n'.join(errors[:max_num_errors])
if errors_length > max_num_errors:
logger.warning(
f"Found more than {max_num_errors} errors. Showing first {max_num_errors} error messages..."
)
raise ValueError(
f"Error while validating annotations. Found {errors_length} annotations with errors. Errors:\n{formatted_errors}"
)

@classmethod
def from_name(cls,
client: "labelbox.Client",
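A sketch of the payload shapes the new `_validate_data_rows` check accepts and flags; it runs when the import serializes objects to NDJSON. The ids, keys, and the `answer` field below are placeholders:

    valid_by_id         = {'dataRow': {'id': 'ckz-data-row-id'}}
    valid_by_global_key = {'dataRow': {'globalKey': 'my-global-key'}}
    flagged_both        = {'dataRow': {'id': 'ckz-data-row-id',
                                       'globalKey': 'my-global-key'}}  # collected as an error: both set
    flagged_missing     = {'answer': 'some text'}                      # collected as an error: no 'dataRow'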
4 changes: 4 additions & 0 deletions labelbox/utils.py
@@ -26,6 +26,10 @@ def snake_case(s):
return _convert(s, "_", lambda i: False)


def is_exactly_one_set(x, y):
return not (bool(x) == bool(y))


class _CamelCaseMixin(BaseModel):

class Config:
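The new helper is an exclusive-or over truthiness, so exactly one of the two values must be set. A quick check with placeholder values:

    from labelbox.utils import is_exactly_one_set

    is_exactly_one_set("ckz-data-row-id", None)              # True
    is_exactly_one_set(None, "my-global-key")                # True
    is_exactly_one_set(None, None)                           # False: neither set
    is_exactly_one_set("ckz-data-row-id", "my-global-key")   # False: both set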