Skip to content

Commit 408cf7a

Browse files
authored
[PLT-71] Fix relationships imports (#1421)
1 parent fcc1821 commit 408cf7a

File tree

9 files changed

+105
-52
lines changed

9 files changed

+105
-52
lines changed

labelbox/data/annotation_types/base_annotation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from uuid import UUID
2+
from uuid import UUID, uuid4
33
from typing import Any, Dict, Optional
44
from labelbox import pydantic_compat
55

@@ -15,4 +15,4 @@ class BaseAnnotation(FeatureSchema, abc.ABC):
1515
def __init__(self, **data):
1616
super().__init__(**data)
1717
extra_uuid = data.get("extra", {}).get("uuid")
18-
self._uuid = data.get("_uuid") or extra_uuid or None
18+
self._uuid = data.get("_uuid") or extra_uuid or uuid4()
Lines changed: 72 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1+
import copy
12
import logging
23
import uuid
3-
from typing import Any, Dict, Generator, Iterable
4+
from collections import defaultdict, deque
5+
from typing import Any, Deque, Dict, Generator, Iterable, List, Set, Union
6+
7+
from labelbox.data.annotation_types.annotation import ObjectAnnotation
8+
from labelbox.data.annotation_types.classification.classification import (
9+
ClassificationAnnotation,)
10+
from labelbox.data.annotation_types.metrics.confusion_matrix import (
11+
ConfusionMatrixMetric,)
12+
from labelbox.data.annotation_types.metrics.scalar import ScalarMetric
13+
from labelbox.data.annotation_types.video import VideoMaskAnnotation
414

515
from ...annotation_types.collection import LabelCollection, LabelGenerator
616
from ...annotation_types.relationship import RelationshipAnnotation
@@ -42,51 +52,69 @@ def serialize(
4252
Returns:
4353
A generator for accessing the ndjson representation of the data
4454
"""
45-
used_annotation_uuids = set()
46-
for label in labels:
47-
annotation_uuid_to_generated_uuid_lookup = {}
48-
# UUIDs are private properties used to enhance UX when defining relationships.
49-
# They are created for all annotations, but only utilized for relationships.
50-
# To avoid overwriting, UUIDs must be unique across labels.
51-
# Non-relationship annotation UUIDs are dropped (server-side generation will occur).
52-
# For relationship annotations, new UUIDs are generated and stored in a lookup table.
53-
for annotation in label.annotations:
54-
if isinstance(annotation, RelationshipAnnotation):
55-
source_uuid = annotation.value.source._uuid
56-
target_uuid = annotation.value.target._uuid
55+
used_uuids: Set[uuid.UUID] = set()
5756

58-
if (len(
59-
used_annotation_uuids.intersection(
60-
{source_uuid, target_uuid})) > 0):
61-
new_source_uuid = uuid.uuid4()
62-
new_target_uuid = uuid.uuid4()
63-
64-
annotation_uuid_to_generated_uuid_lookup[
65-
source_uuid] = new_source_uuid
66-
annotation_uuid_to_generated_uuid_lookup[
67-
target_uuid] = new_target_uuid
68-
annotation.value.source._uuid = new_source_uuid
69-
annotation.value.target._uuid = new_target_uuid
70-
else:
71-
annotation_uuid_to_generated_uuid_lookup[
72-
source_uuid] = source_uuid
73-
annotation_uuid_to_generated_uuid_lookup[
74-
target_uuid] = target_uuid
75-
used_annotation_uuids.add(annotation._uuid)
57+
relationship_uuids: Dict[uuid.UUID,
58+
Deque[uuid.UUID]] = defaultdict(deque)
7659

60+
# UUIDs are private properties used to enhance UX when defining relationships.
61+
# They are created for all annotations, but only utilized for relationships.
62+
# To avoid overwriting, UUIDs must be unique across labels.
63+
# Non-relationship annotation UUIDs are regenerated when they are reused.
64+
# For relationship annotations, during first pass, we update the UUIDs of the source and target annotations.
65+
# During the second pass, we update the UUIDs of the annotations referenced by the relationship annotations.
66+
for label in labels:
67+
uuid_safe_annotations: List[Union[
68+
ClassificationAnnotation,
69+
ObjectAnnotation,
70+
VideoMaskAnnotation,
71+
ScalarMetric,
72+
ConfusionMatrixMetric,
73+
RelationshipAnnotation,
74+
]] = []
75+
# First pass to get all RelatiohnshipAnnotaitons
76+
# and update the UUIDs of the source and target annotations
77+
for relationship_annotation in (
78+
annotation for annotation in label.annotations
79+
if isinstance(annotation, RelationshipAnnotation)):
80+
if relationship_annotation in uuid_safe_annotations:
81+
relationship_annotation = copy.deepcopy(
82+
relationship_annotation)
83+
new_source_uuid = uuid.uuid4()
84+
new_target_uuid = uuid.uuid4()
85+
relationship_uuids[relationship_annotation.value.source.
86+
_uuid].append(new_source_uuid)
87+
relationship_uuids[relationship_annotation.value.target.
88+
_uuid].append(new_target_uuid)
89+
relationship_annotation.value.source._uuid = new_source_uuid
90+
relationship_annotation.value.target._uuid = new_target_uuid
91+
if relationship_annotation._uuid in used_uuids:
92+
relationship_annotation._uuid = uuid.uuid4()
93+
used_uuids.add(relationship_annotation._uuid)
94+
uuid_safe_annotations.append(relationship_annotation)
95+
# Second pass to update UUIDs for annotations referenced by RelationshipAnnotations
7796
for annotation in label.annotations:
78-
if (not isinstance(annotation, RelationshipAnnotation) and
79-
hasattr(annotation, "_uuid")):
80-
annotation._uuid = annotation_uuid_to_generated_uuid_lookup.get(
81-
annotation._uuid, annotation._uuid)
97+
if not isinstance(annotation, RelationshipAnnotation):
98+
if hasattr(annotation, "_uuid"):
99+
if annotation in uuid_safe_annotations:
100+
annotation = copy.deepcopy(annotation)
101+
next_uuids = relationship_uuids[annotation._uuid]
102+
if len(next_uuids) > 0:
103+
annotation._uuid = next_uuids.popleft()
82104

83-
for example in NDLabel.from_common(labels):
84-
annotation_uuid = getattr(example, "uuid", None)
105+
if annotation._uuid in used_uuids:
106+
annotation._uuid = uuid.uuid4()
107+
used_uuids.add(annotation._uuid)
108+
uuid_safe_annotations.append(annotation)
109+
label.annotations = uuid_safe_annotations
110+
for example in NDLabel.from_common([label]):
111+
annotation_uuid = getattr(example, "uuid", None)
85112

86-
res = example.dict(
87-
by_alias=True,
88-
exclude={"uuid"} if annotation_uuid == "None" else None)
89-
for k, v in list(res.items()):
90-
if k in IGNORE_IF_NONE and v is None:
91-
del res[k]
92-
yield res
113+
res = example.dict(
114+
by_alias=True,
115+
exclude={"uuid"} if annotation_uuid == "None" else None,
116+
)
117+
for k, v in list(res.items()):
118+
if k in IGNORE_IF_NONE and v is None:
119+
del res[k]
120+
yield res

tests/data/serialization/ndjson/test_checklist.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_serialization_min():
3232
}
3333
serialized = NDJsonConverter.serialize([label])
3434
res = next(serialized)
35-
35+
res.pop("uuid")
3636
assert res == expected
3737

3838
deserialized = NDJsonConverter.deserialize([res])
@@ -112,6 +112,7 @@ def test_serialization_with_classification():
112112
serialized = NDJsonConverter.serialize([label])
113113
res = next(serialized)
114114

115+
res.pop("uuid")
115116
assert res == expected
116117

117118
deserialized = NDJsonConverter.deserialize([res])
@@ -195,6 +196,7 @@ def test_serialization_with_classification_double_nested():
195196
serialized = NDJsonConverter.serialize([label])
196197
res = next(serialized)
197198

199+
res.pop("uuid")
198200
assert res == expected
199201

200202
deserialized = NDJsonConverter.deserialize([res])
@@ -274,6 +276,7 @@ def test_serialization_with_classification_double_nested_2():
274276

275277
serialized = NDJsonConverter.serialize([label])
276278
res = next(serialized)
279+
res.pop("uuid")
277280
assert res == expected
278281

279282
deserialized = NDJsonConverter.deserialize([res])

tests/data/serialization/ndjson/test_conversation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
[free_text_label, free_text_ndjson]])
8484
def test_message_based_radio_classification(label, ndjson):
8585
serialized_label = list(NDJsonConverter().serialize(label))
86+
serialized_label[0].pop('uuid')
8687
assert serialized_label == ndjson
8788

8889
deserialized_label = list(NDJsonConverter().deserialize(ndjson))

tests/data/serialization/ndjson/test_document.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def test_pdf_with_name_only():
6969

7070
def test_pdf_bbox_serialize():
7171
serialized = list(NDJsonConverter.serialize(bbox_labels))
72+
serialized[0].pop('uuid')
7273
assert serialized == bbox_ndjson
7374

7475

tests/data/serialization/ndjson/test_image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def test_mask_from_arr():
9090
],
9191
data=ImageData(uid="0" * 25))
9292
res = next(NDJsonConverter.serialize([label]))
93+
res.pop("uuid")
9394
assert res == {
9495
"classifications": [],
9596
"schemaId": "1" * 25,

tests/data/serialization/ndjson/test_radio.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def test_serialization_with_radio_min():
3434
serialized = NDJsonConverter.serialize([label])
3535
res = next(serialized)
3636

37+
res.pop("uuid")
3738
assert res == expected
3839

3940
deserialized = NDJsonConverter.deserialize([res])
@@ -85,6 +86,7 @@ def test_serialization_with_radio_classification():
8586

8687
serialized = NDJsonConverter.serialize([label])
8788
res = next(serialized)
89+
res.pop("uuid")
8890
assert res == expected
8991

9092
deserialized = NDJsonConverter.deserialize([res])

tests/data/serialization/ndjson/test_relationship.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,37 @@
11
import json
2-
import pytest
32
from uuid import uuid4
43

4+
import pytest
5+
56
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
67

78

89
def test_relationship():
9-
with open('tests/data/assets/ndjson/relationship_import.json', 'r') as file:
10+
with open("tests/data/assets/ndjson/relationship_import.json", "r") as file:
1011
data = json.load(file)
1112

1213
res = list(NDJsonConverter.deserialize(data))
1314
res = list(NDJsonConverter.serialize(res))
15+
assert len(res) == len(data)
16+
17+
res_relationship_annotation = [
18+
annot for annot in res if "relationship" in annot
19+
][0]
20+
res_source_and_target = [
21+
annot for annot in res if "relationship" not in annot
22+
]
23+
assert res_relationship_annotation
1424

15-
assert res == data
25+
assert res_relationship_annotation["relationship"]["source"] in [
26+
annot["uuid"] for annot in res_source_and_target
27+
]
28+
assert res_relationship_annotation["relationship"]["target"] in [
29+
annot["uuid"] for annot in res_source_and_target
30+
]
1631

1732

1833
def test_relationship_nonexistent_object():
19-
with open('tests/data/assets/ndjson/relationship_import.json', 'r') as file:
34+
with open("tests/data/assets/ndjson/relationship_import.json", "r") as file:
2035
data = json.load(file)
2136

2237
relationship_annotation = data[2]
@@ -30,7 +45,7 @@ def test_relationship_nonexistent_object():
3045

3146

3247
def test_relationship_duplicate_uuids():
33-
with open('tests/data/assets/ndjson/relationship_import.json', 'r') as file:
48+
with open("tests/data/assets/ndjson/relationship_import.json", "r") as file:
3449
data = json.load(file)
3550

3651
source, target = data[0], data[1]

tests/data/serialization/ndjson/test_video.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def test_video_classification_global_subclassifications():
8787

8888
serialized = NDJsonConverter.serialize([label])
8989
res = [x for x in serialized]
90+
for annotations in res:
91+
annotations.pop("uuid")
9092
assert res == [expected_first_annotation, expected_second_annotation]
9193

9294
deserialized = NDJsonConverter.deserialize(res)

0 commit comments

Comments
 (0)