From 5ffc6dc2825b00260b75a64ecf013a6649a5fb57 Mon Sep 17 00:00:00 2001 From: Ibrahim Muhammad Date: Fri, 17 Mar 2023 15:14:07 -0700 Subject: [PATCH] Fix annotation data type coersion by Pydantic --- labelbox/data/annotation_types/data/audio.py | 6 +++-- .../annotation_types/data/conversation.py | 4 +++- labelbox/data/annotation_types/data/dicom.py | 6 +++-- .../data/annotation_types/data/document.py | 6 +++-- labelbox/data/annotation_types/data/html.py | 6 +++-- labelbox/data/annotation_types/data/text.py | 5 +++- labelbox/data/annotation_types/label.py | 7 ++++-- labelbox/typing_imports.py | 10 ++++++++ labelbox/utils.py | 23 +++++++++++++++++++ tests/data/annotation_types/test_label.py | 12 ++++++++++ 10 files changed, 73 insertions(+), 12 deletions(-) create mode 100644 labelbox/typing_imports.py diff --git a/labelbox/data/annotation_types/data/audio.py b/labelbox/data/annotation_types/data/audio.py index 5263d9d2a..76be33110 100644 --- a/labelbox/data/annotation_types/data/audio.py +++ b/labelbox/data/annotation_types/data/audio.py @@ -1,5 +1,7 @@ +from labelbox.typing_imports import Literal +from labelbox.utils import _NoCoercionMixin from .base_data import BaseData -class AudioData(BaseData): - ... \ No newline at end of file +class AudioData(BaseData, _NoCoercionMixin): + class_name: Literal["AudioData"] = "AudioData" \ No newline at end of file diff --git a/labelbox/data/annotation_types/data/conversation.py b/labelbox/data/annotation_types/data/conversation.py index 3d7633f28..302b2c487 100644 --- a/labelbox/data/annotation_types/data/conversation.py +++ b/labelbox/data/annotation_types/data/conversation.py @@ -1,5 +1,7 @@ +from labelbox.typing_imports import Literal +from labelbox.utils import _NoCoercionMixin from .base_data import BaseData class ConversationData(BaseData): - ... \ No newline at end of file + class_name: Literal["ConversationData"] = "ConversationData" \ No newline at end of file diff --git a/labelbox/data/annotation_types/data/dicom.py b/labelbox/data/annotation_types/data/dicom.py index 9ebd242b8..753475c3e 100644 --- a/labelbox/data/annotation_types/data/dicom.py +++ b/labelbox/data/annotation_types/data/dicom.py @@ -1,5 +1,7 @@ +from labelbox.typing_imports import Literal +from labelbox.utils import _NoCoercionMixin from .base_data import BaseData -class DicomData(BaseData): - ... \ No newline at end of file +class DicomData(BaseData, _NoCoercionMixin): + class_name: Literal["DicomData"] = "DicomData" \ No newline at end of file diff --git a/labelbox/data/annotation_types/data/document.py b/labelbox/data/annotation_types/data/document.py index 8488812ca..5b2610c5b 100644 --- a/labelbox/data/annotation_types/data/document.py +++ b/labelbox/data/annotation_types/data/document.py @@ -1,5 +1,7 @@ +from labelbox.typing_imports import Literal +from labelbox.utils import _NoCoercionMixin from .base_data import BaseData -class DocumentData(BaseData): - ... \ No newline at end of file +class DocumentData(BaseData, _NoCoercionMixin): + class_name: Literal["DocumentData"] = "DocumentData" \ No newline at end of file diff --git a/labelbox/data/annotation_types/data/html.py b/labelbox/data/annotation_types/data/html.py index f9fc31e3f..1820ce467 100644 --- a/labelbox/data/annotation_types/data/html.py +++ b/labelbox/data/annotation_types/data/html.py @@ -1,5 +1,7 @@ +from labelbox.typing_imports import Literal +from labelbox.utils import _NoCoercionMixin from .base_data import BaseData -class HTMLData(BaseData): - ... \ No newline at end of file +class HTMLData(BaseData, _NoCoercionMixin): + class_name: Literal["HTMLData"] = "HTMLData" \ No newline at end of file diff --git a/labelbox/data/annotation_types/data/text.py b/labelbox/data/annotation_types/data/text.py index 9149b0394..704d2e7b4 100644 --- a/labelbox/data/annotation_types/data/text.py +++ b/labelbox/data/annotation_types/data/text.py @@ -6,10 +6,12 @@ from pydantic import root_validator from labelbox.exceptions import InternalServerError +from labelbox.typing_imports import Literal +from labelbox.utils import _NoCoercionMixin from .base_data import BaseData -class TextData(BaseData): +class TextData(BaseData, _NoCoercionMixin): """ Represents text data. Requires arg file_path, text, or url @@ -20,6 +22,7 @@ class TextData(BaseData): text (str) url (str) """ + class_name: Literal["TextData"] = "TextData" file_path: Optional[str] = None text: Optional[str] = None url: Optional[str] = None diff --git a/labelbox/data/annotation_types/label.py b/labelbox/data/annotation_types/label.py index 962dc402e..5aa68a8f4 100644 --- a/labelbox/data/annotation_types/label.py +++ b/labelbox/data/annotation_types/label.py @@ -11,12 +11,15 @@ VideoClassificationAnnotation, VideoObjectAnnotation, DICOMObjectAnnotation) from .classification import ClassificationAnswer -from .data import DicomData, VideoData, TextData, ImageData +from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, MaskData, TextData, VideoData from .geometry import Mask from .metrics import ScalarMetric, ConfusionMatrixMetric from .types import Cuid from ..ontology import get_feature_schema_lookup +DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData, + ConversationData, DicomData, DocumentData, HTMLData] + class Label(BaseModel): """Container for holding data and annotations @@ -38,7 +41,7 @@ class Label(BaseModel): extra: additional context """ uid: Optional[Cuid] = None - data: Union[VideoData, ImageData, TextData, TiledImageData] + data: DataType annotations: List[Union[ClassificationAnnotation, ObjectAnnotation, ScalarMetric, ConfusionMatrixMetric]] = [] extra: Dict[str, Any] = {} diff --git a/labelbox/typing_imports.py b/labelbox/typing_imports.py new file mode 100644 index 000000000..2c2716710 --- /dev/null +++ b/labelbox/typing_imports.py @@ -0,0 +1,10 @@ +""" +This module imports types that differ across python versions, so other modules +don't have to worry about where they should be imported from. +""" + +import sys +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal \ No newline at end of file diff --git a/labelbox/utils.py b/labelbox/utils.py index 27106abb5..dffa5694e 100644 --- a/labelbox/utils.py +++ b/labelbox/utils.py @@ -35,3 +35,26 @@ class _CamelCaseMixin(BaseModel): class Config: allow_population_by_field_name = True alias_generator = camel_case + + +class _NoCoercionMixin: + """ + When using Unions in type annotations, pydantic will try to coerce the type + of the object to the type of the first Union member. Which results in + uninteded behavior. + + This mixin uses a class_name discriminator field to prevent pydantic from + corecing the type of the object. Add a class_name field to the class you + want to discrimniate and use this mixin class to remove the discriminator + when serializing the object. + + Example: + class ConversationData(BaseData, _NoCoercionMixin): + class_name: Literal["ConversationData"] = "ConversationData" + + """ + + def dict(self, *args, **kwargs): + res = super().dict(*args, **kwargs) + res.pop('class_name') + return res diff --git a/tests/data/annotation_types/test_label.py b/tests/data/annotation_types/test_label.py index 105f6c939..ee83b1b50 100644 --- a/tests/data/annotation_types/test_label.py +++ b/tests/data/annotation_types/test_label.py @@ -1,5 +1,6 @@ import numpy as np +import labelbox.types as lb_types from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text, ClassificationAnnotation, @@ -181,3 +182,14 @@ def test_schema_assignment_confidence(): ]) assert label.annotations[0].confidence == 0.914 + + +def test_initialize_label_no_coercion(): + global_key = 'global-key' + ner_annotation = lb_types.ObjectAnnotation( + name="ner", + value=lb_types.ConversationEntity(start=0, end=8, message_id="4")) + label = Label(data=lb_types.ConversationData(global_key=global_key), + annotations=[ner_annotation]) + assert isinstance(label.data, lb_types.ConversationData) + assert label.data.global_key == global_key