Skip to content

Commit a50ca2b

Browse files
committed
Fix annotation data type coercion by Pydantic
1 parent 18fe1c1 commit a50ca2b

File tree

9 files changed

+69
-11
lines changed

9 files changed

+69
-11
lines changed
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

4-
class AudioData(BaseData):
5-
...
6+
class AudioData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["AudioData"] = "AudioData"
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

46
class ConversationData(BaseData, _NoCoercionMixin):
    """Conversation data row, tagged with a ``class_name`` discriminator."""
    # Fixed literal tag; a value carrying a different class_name cannot
    # validate as this class when it appears as a member of a Union.
    class_name: Literal["ConversationData"] = "ConversationData"
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

4-
class DicomData(BaseData):
5-
...
6+
class DicomData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["DicomData"] = "DicomData"
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

4-
class DocumentData(BaseData):
5-
...
6+
class DocumentData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["DocumentData"] = "DocumentData"
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

4-
class HTMLData(BaseData):
5-
...
6+
class HTMLData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["HTMLData"] = "HTMLData"

labelbox/data/annotation_types/label.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
VideoClassificationAnnotation, VideoObjectAnnotation,
1212
DICOMObjectAnnotation)
1313
from .classification import ClassificationAnswer
14-
from .data import DicomData, VideoData, TextData, ImageData
14+
from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, MaskData, TextData, VideoData
1515
from .geometry import Mask
1616
from .metrics import ScalarMetric, ConfusionMatrixMetric
1717
from .types import Cuid
1818
from ..ontology import get_feature_schema_lookup
1919

20+
# Union of the data classes accepted by Label.data.
# NOTE(review): the annotation this alias replaces also listed TiledImageData,
# which does not appear below -- confirm that dropping it is intended.
DataType = Union[AudioData, ConversationData, DicomData, DocumentData, HTMLData,
21+
ImageData, MaskData, TextData, VideoData]
22+
2023

2124
class Label(BaseModel):
2225
"""Container for holding data and annotations
@@ -38,7 +41,7 @@ class Label(BaseModel):
3841
extra: additional context
3942
"""
4043
uid: Optional[Cuid] = None
41-
data: Union[VideoData, ImageData, TextData, TiledImageData]
44+
data: DataType
4245
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
4346
ScalarMetric, ConfusionMatrixMetric]] = []
4447
extra: Dict[str, Any] = {}

labelbox/typing_imports.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
This module imports types that differ across python versions, so other modules
3+
don't have to worry about where they should be imported from.
4+
"""
5+
6+
import sys
7+
if sys.version_info >= (3, 8):
8+
from typing import Literal
9+
else:
10+
from typing_extensions import Literal

labelbox/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,26 @@ class _CamelCaseMixin(BaseModel):
3535
class Config:
3636
allow_population_by_field_name = True
3737
alias_generator = camel_case
38+
39+
40+
class _NoCoercionMixin:
    """
    Hide the ``class_name`` discriminator field when serializing with dict().

    When using Unions in type annotations, pydantic will try to coerce the type
    of the object to the type of the first Union member, which results in
    unintended behavior.

    This mixin relies on a class_name discriminator field to prevent pydantic
    from coercing the type of the object. Add a class_name field to the class
    you want to discriminate and use this mixin class to remove the
    discriminator when serializing the object.

    Example:
        class ConversationData(BaseData, _NoCoercionMixin):
            class_name: Literal["ConversationData"] = "ConversationData"

    NOTE(review): dict() below delegates to super().dict(), so the override is
    only reached (and only works) when a class that provides dict() comes
    after this mixin in the MRO -- confirm the base-class order at each use
    site.
    """

    def dict(self, *args, **kwargs):
        """Return the parent's dict() output without the ``class_name`` key.

        All positional and keyword arguments are forwarded unchanged to the
        next dict() implementation in the MRO.
        """
        res = super().dict(*args, **kwargs)
        # The key may already be absent (e.g. the caller excluded the field),
        # so pop with a default instead of raising KeyError.
        res.pop('class_name', None)
        return res

tests/data/annotation_types/test_label.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
import labelbox.types as lb_types
34
from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option
45
from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text,
56
ClassificationAnnotation,
@@ -181,3 +182,14 @@ def test_schema_assignment_confidence():
181182
])
182183

183184
assert label.annotations[0].confidence == 0.914
185+
186+
187+
def test_initialize_label_no_coercion():
    """A Label built from ConversationData must keep it as ConversationData."""
    expected_key = 'global-key'
    entity = lb_types.ConversationEntity(start=0, end=8, message_id="4")
    annotation = lb_types.ObjectAnnotation(name="ner", value=entity)
    conversation = lb_types.ConversationData(global_key=expected_key)

    label = Label(data=conversation, annotations=[annotation])

    # The stored data must still be the class that was passed in, with its
    # global key intact.
    assert isinstance(label.data, lb_types.ConversationData)
    assert label.data.global_key == expected_key

0 commit comments

Comments
 (0)