Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion labelbox/data/annotation_types/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class BaseAnnotation(FeatureSchema, abc.ABC):
extra: Dict[str, Any] = {}


class ClassificationAnnotation(BaseAnnotation):
class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin):
"""Classification annotations (non localized)

>>> ClassificationAnnotation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


# TODO: Replace when pydantic adds support for unions that don't coerce types
class _TempName(BaseModel):
class _TempName(ConfidenceMixin, BaseModel):
name: str

def dict(self, *args, **kwargs):
Expand Down Expand Up @@ -43,7 +43,7 @@ def dict(self, *args, **kwargs) -> Dict[str, str]:
return res


class Radio(BaseModel):
class Radio(ConfidenceMixin, BaseModel):
""" A classification with only one selected option allowed

>>> Radio(answer = ClassificationAnswer(name = "dog"))
Expand All @@ -62,7 +62,7 @@ class Checklist(_TempName):
answer: List[ClassificationAnswer]


class Text(BaseModel):
class Text(ConfidenceMixin, BaseModel):
""" Free form text

>>> Text(answer = "some text answer")
Expand Down
44 changes: 30 additions & 14 deletions labelbox/data/serialization/ndjson/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,33 @@ def from_common(cls, radio: Radio, name: str,
class NDText(NDAnnotation, NDTextSubclass):

@classmethod
def from_common(cls, text: Text, name: str, feature_schema_id: Cuid,
extra: Dict[str, Any], data: Union[TextData,
ImageData]) -> "NDText":
def from_common(cls,
text: Text,
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
data: Union[TextData, ImageData],
confidence: Optional[float] = None) -> "NDText":
return cls(
answer=text.answer,
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
confidence=confidence,
)


class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):

@classmethod
def from_common(
cls, checklist: Checklist, name: str, feature_schema_id: Cuid,
extra: Dict[str, Any], data: Union[VideoData, TextData,
ImageData]) -> "NDChecklist":
def from_common(cls,
checklist: Checklist,
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
data: Union[VideoData, TextData, ImageData],
confidence: Optional[float] = None) -> "NDChecklist":
return cls(answer=[
NDFeature(name=answer.name,
schema_id=answer.feature_schema_id,
Expand All @@ -149,23 +157,29 @@ def from_common(
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
frames=extra.get('frames'))
frames=extra.get('frames'),
confidence=confidence)


class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported):

@classmethod
def from_common(cls, radio: Radio, name: str, feature_schema_id: Cuid,
extra: Dict[str, Any], data: Union[VideoData, TextData,
ImageData]) -> "NDRadio":
def from_common(cls,
radio: Radio,
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
data: Union[VideoData, TextData, ImageData],
confidence: Optional[float] = None) -> "NDRadio":
return cls(answer=NDFeature(name=radio.answer.name,
schema_id=radio.answer.feature_schema_id,
confidence=radio.answer.confidence),
data_row=DataRow(id=data.uid, global_key=data.global_key),
name=name,
schema_id=feature_schema_id,
uuid=extra.get('uuid'),
frames=extra.get('frames'))
frames=extra.get('frames'),
confidence=confidence)


class NDSubclassification:
Expand Down Expand Up @@ -212,7 +226,8 @@ def to_common(
value=annotation.to_common(),
name=annotation.name,
feature_schema_id=annotation.schema_id,
extra={'uuid': annotation.uuid})
extra={'uuid': annotation.uuid},
confidence=annotation.confidence)
if getattr(annotation, 'frames', None) is None:
return [common]
results = []
Expand All @@ -235,7 +250,8 @@ def from_common(
)
return classify_obj.from_common(annotation.value, annotation.name,
annotation.feature_schema_id,
annotation.extra, data)
annotation.extra, data,
annotation.confidence)

@staticmethod
def lookup_classification(
Expand Down
12 changes: 12 additions & 0 deletions tests/data/annotation_types/test_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from labelbox.data.annotation_types.classification.classification import Text


def test_text():
text_entity = Text(answer="good job")
assert text_entity.answer == "good job"


def test_text_confidence():
text_entity = Text(answer="good job", confidence=0.5)
assert text_entity.answer == "good job"
assert text_entity.confidence == 0.5
25 changes: 0 additions & 25 deletions tests/data/assets/ndjson/text_import.json

This file was deleted.

15 changes: 0 additions & 15 deletions tests/data/assets/ndjson/text_import_name_only.json

This file was deleted.

43 changes: 43 additions & 0 deletions tests/data/serialization/ndjson/test_checklist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio
from labelbox.data.annotation_types.data.text import TextData
from labelbox.data.annotation_types.label import Label

from labelbox.data.serialization.ndjson.converter import NDJsonConverter


def test_serialization():
label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d",
data=TextData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
text="This is a test",
),
annotations=[
ClassificationAnnotation(
name="checkbox_question_geo",
confidence=0.5,
value=Checklist(answer=[
ClassificationAnswer(name="first_answer"),
ClassificationAnswer(name="second_answer")
]))
])

serialized = NDJsonConverter.serialize([label])

res = next(serialized)
assert res['confidence'] == 0.5
assert res['name'] == "checkbox_question_geo"
assert res['answer'][0]['name'] == "first_answer"
assert res['answer'][1]['name'] == "second_answer"
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"

deserialized = NDJsonConverter.deserialize([res])
res = next(deserialized)
annotation = res.annotations[0]
assert annotation.confidence == 0.5

annotation_value = annotation.value
assert type(annotation_value) is Checklist
assert annotation_value.answer[0].name == "first_answer"
assert annotation_value.answer[1].name == "second_answer"
39 changes: 39 additions & 0 deletions tests/data/serialization/ndjson/test_radio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import json
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio
from labelbox.data.annotation_types.data.text import TextData
from labelbox.data.annotation_types.label import Label

from labelbox.data.serialization.ndjson.converter import NDJsonConverter


def test_serialization():
label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d",
data=TextData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
text="This is a test",
),
annotations=[
ClassificationAnnotation(
name="radio_question_geo",
confidence=0.5,
value=Radio(answer=ClassificationAnswer(
confidence=0.6, name="first_radio_answer")))
])

serialized = NDJsonConverter.serialize([label])
res = next(serialized)
assert res['confidence'] == 0.5
assert res['name'] == "radio_question_geo"
assert res['answer']['name'] == "first_radio_answer"
assert res['answer']['confidence'] == 0.6
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"

deserialized = NDJsonConverter.deserialize([res])
res = next(deserialized)
annotation = res.annotations[0]
assert annotation.confidence == 0.5

annotation_value = annotation.value
assert type(annotation_value) is Radio
assert annotation_value.answer.name == "first_radio_answer"
43 changes: 30 additions & 13 deletions tests/data/serialization/ndjson/test_text.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,37 @@
import json
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio, Text
from labelbox.data.annotation_types.data.text import TextData
from labelbox.data.annotation_types.label import Label

from labelbox.data.serialization.ndjson.converter import NDJsonConverter


def test_text():
with open('tests/data/assets/ndjson/text_import.json', 'r') as file:
data = json.load(file)
res = list(NDJsonConverter.deserialize(data))
res = list(NDJsonConverter.serialize(res))
assert res == data
def test_serialization():
label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d",
data=TextData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
text="This is a test",
),
annotations=[
ClassificationAnnotation(
name="radio_question_geo",
confidence=0.5,
value=Text(answer="first_radio_answer"))
])

serialized = NDJsonConverter.serialize([label])
res = next(serialized)
assert res['confidence'] == 0.5
assert res['name'] == "radio_question_geo"
assert res['answer'] == "first_radio_answer"
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"

def test_text_name_only():
with open('tests/data/assets/ndjson/text_import_name_only.json',
'r') as file:
data = json.load(file)
res = list(NDJsonConverter.deserialize(data))
res = list(NDJsonConverter.serialize(res))
assert res == data
deserialized = NDJsonConverter.deserialize([res])
res = next(deserialized)
annotation = res.annotations[0]
assert annotation.confidence == 0.5

annotation_value = annotation.value
assert type(annotation_value) is Text
assert annotation_value.answer == "first_radio_answer"
8 changes: 6 additions & 2 deletions tests/integration/annotation_import/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def model_run_with_model_run_data_rows(client, configured_project,
labels = wait_for_label_processing(configured_project)
label_ids = [label.uid for label in labels]
model_run.upsert_labels(label_ids)
time.sleep(3)
time.sleep(300)
yield model_run
model_run.delete()
# TODO: Delete resources when that is possible ..
Expand All @@ -670,6 +670,11 @@ def model_run_with_all_project_labels(client, configured_project,
wait_for_label_processing):
configured_project.enable_model_assisted_labeling()

data_row_ids = configured_project.data_row_ids

configured_project._wait_until_data_rows_are_processed(
data_row_ids=data_row_ids)

upload_task = LabelImport.create_from_objects(
client, configured_project.uid, f"label-import-{uuid.uuid4()}",
model_run_predictions)
Expand All @@ -680,7 +685,6 @@ def model_run_with_all_project_labels(client, configured_project,
) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}"
wait_for_label_processing(configured_project)
model_run.upsert_labels(project_id=configured_project.uid)
time.sleep(3)
yield model_run
model_run.delete()
# TODO: Delete resources when that is possible ..
Expand Down