generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 150
/
voc_parser.py
170 lines (139 loc) · 5.18 KB
/
voc_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Public names exported by `from ... import *`: the `voc` factory (slated for
# deprecation, see its warning) and the two concrete parser classes.
__all__ = ["voc", "VOCBBoxParser", "VOCMaskParser"]
import xml.etree.ElementTree as ET
from icevision.imports import *
from icevision.utils import *
from icevision.core import *
from icevision.parsers.parser import *
def voc(
    annotations_dir: Union[str, Path],
    images_dir: Union[str, Path],
    class_map: Optional[ClassMap] = None,
    masks_dir: Optional[Union[str, Path]] = None,
    idmap: Optional[IDMap] = None,
):
    """Create a Pascal VOC parser (legacy factory).

    Logs a warning recommending direct use of the concrete classes, then
    returns a ``VOCMaskParser`` when ``masks_dir`` is truthy, otherwise a
    ``VOCBBoxParser``. All other arguments are forwarded unchanged.
    """
    logger.warning(
        "This function will be deprecated, instantiate the concrete "
        "classes instead: `VOCBBoxParser`, `VOCMaskParser`"
    )

    shared_kwargs = dict(
        annotations_dir=annotations_dir,
        images_dir=images_dir,
        class_map=class_map,
        idmap=idmap,
    )
    if masks_dir:
        return VOCMaskParser(masks_dir=masks_dir, **shared_kwargs)
    return VOCBBoxParser(**shared_kwargs)
# Bounding-box parser for Pascal VOC XML annotation files.
class VOCBBoxParser(Parser):
    """Parse Pascal VOC XML annotation files into bounding-box records.

    Every ``.xml`` file found under ``annotations_dir`` yields one record
    holding the image filepath, the image size, the object labels and the
    object bounding boxes.

    ``prepare(o)`` must run before ``record_id``, ``filepath``, ``img_size``,
    ``labels`` and ``bboxes`` for the same ``o``: those methods read the
    parsed XML state that ``prepare`` stores on ``self``.
    """

    def __init__(
        self,
        annotations_dir: Union[str, Path],
        images_dir: Union[str, Path],
        class_map: Optional[ClassMap] = None,
        idmap: Optional[IDMap] = None,
    ):
        """
        Args:
            annotations_dir: Directory searched for ``.xml`` annotation files.
            images_dir: Directory that each annotation's ``<filename>`` is
                resolved against.
            class_map: Class map attached to every record. When falsy, a new
                ``ClassMap()`` with ``unlock()`` applied is used (presumably so
                unseen labels can be added during parsing — confirm in ClassMap).
            idmap: Optional id map forwarded to ``Parser``.
        """
        super().__init__(template_record=self.template_record(), idmap=idmap)
        self.class_map = class_map or ClassMap().unlock()
        self.images_dir = Path(images_dir)
        self.annotations_dir = Path(annotations_dir)
        self.annotation_files = get_files(self.annotations_dir, extensions=[".xml"])

    def __len__(self):
        """Number of annotation files discovered."""
        return len(self.annotation_files)

    def __iter__(self):
        """Yield each annotation file path."""
        yield from self.annotation_files

    def template_record(self) -> BaseRecord:
        """Record layout: filepath, instance labels and bounding boxes."""
        return BaseRecord(
            (
                FilepathRecordComponent(),
                InstancesLabelsRecordComponent(),
                BBoxesRecordComponent(),
            )
        )

    def record_id(self, o) -> Hashable:
        """Record id is the stem of the ``<filename>`` read by ``prepare``.

        ``o`` itself is not used; the id comes from the XML content, not
        from the annotation file's own name.
        """
        return str(Path(self._filename).stem)

    def prepare(self, o):
        """Parse annotation file ``o`` and cache its root, filename and size."""
        # NOTE(review): xml.etree.ElementTree is not secure against maliciously
        # constructed XML (Python docs); only parse trusted annotation files.
        tree = ET.parse(str(o))
        self._root = tree.getroot()
        self._filename = self._root.find("filename").text
        self._size = self._root.find("size")

    def parse_fields(self, o, record, is_new):
        """Fill ``record``; filepath and size are only set for new records."""
        if is_new:
            record.set_filepath(self.filepath(o))
            record.set_img_size(self.img_size(o))

        record.detection.set_class_map(self.class_map)
        record.detection.add_labels(self.labels(o))
        record.detection.add_bboxes(self.bboxes(o))

    def filepath(self, o) -> Union[str, Path]:
        """Image path: ``images_dir`` joined with the XML ``<filename>``."""
        return self.images_dir / self._filename

    def img_size(self, o) -> ImgSize:
        """Image size from the XML ``<size>/<width>`` and ``<size>/<height>``."""
        width = int(self._size.find("width").text)
        height = int(self._size.find("height").text)
        return ImgSize(width=width, height=height)

    def labels(self, o) -> List[Hashable]:
        """``<name>`` text of every ``<object>`` element, in document order."""
        return [obj.find("name").text for obj in self._root.iter("object")]

    def bboxes(self, o) -> List[BBox]:
        """One ``BBox`` per ``<object>``, in the same order as ``labels``."""

        def to_int(text):
            # Accepts float strings such as "12.7" and truncates toward zero.
            return int(float(text))

        bboxes = []
        for obj in self._root.iter("object"):
            xml_bbox = obj.find("bndbox")
            coords = [
                to_int(xml_bbox.find(tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ]
            bboxes.append(BBox.from_xyxy(*coords))
        return bboxes
class VOCMaskParser(VOCBBoxParser):
    """Pascal VOC parser that also attaches one mask file per record.

    Only annotation files whose record id (the stem of the XML
    ``<filename>``) equals the stem of an image file found in ``masks_dir``
    are kept; annotations without a matching mask are skipped.
    """

    def __init__(
        self,
        annotations_dir: Union[str, Path],
        images_dir: Union[str, Path],
        masks_dir: Union[str, Path],
        class_map: Optional[ClassMap] = None,
        idmap: Optional[IDMap] = None,
    ):
        """
        Args:
            annotations_dir: Directory searched for ``.xml`` annotation files.
            images_dir: Directory the XML ``<filename>`` entries resolve against.
            masks_dir: Directory searched for mask image files.
            class_map: Forwarded to ``VOCBBoxParser``.
            idmap: Forwarded to ``VOCBBoxParser``.
        """
        super().__init__(
            annotations_dir=annotations_dir,
            images_dir=images_dir,
            class_map=class_map,
            idmap=idmap,
        )
        # Normalise to Path, consistent with images_dir / annotations_dir.
        self.masks_dir = Path(masks_dir)
        self.mask_files = get_image_files(self.masks_dir)
        self._record_id2maskfile = {self.record_id_mask(o): o for o in self.mask_files}

        # Keep only annotations that have a mask. The record id comes from the
        # XML content, so each file must be parsed (``prepare``) to learn it.
        # An explicit loop is used because zero-arg super() does not work
        # inside a comprehension before Python 3.12.
        self._intersection = []
        for item in super().__iter__():
            super().prepare(item)
            if super().record_id(item) in self._record_id2maskfile:
                self._intersection.append(item)

    def __len__(self):
        """Number of annotation files that have a matching mask."""
        return len(self._intersection)

    def __iter__(self):
        """Yield only the annotation files that have a matching mask."""
        yield from self._intersection

    def template_record(self) -> BaseRecord:
        """Parent record layout plus an instance-masks component."""
        record = super().template_record()
        record.add_component(InstanceMasksRecordComponent())
        return record

    def record_id_mask(self, o) -> Hashable:
        """Record id of mask file ``o`` (its stem).

        Must produce the same value as ``record_id`` for the annotation the
        mask belongs to, otherwise the pair is not matched.
        """
        return str(Path(o).stem)

    def parse_fields(self, o, record, is_new):
        """Fill bbox fields via the parent, then add this record's mask."""
        super().parse_fields(o, record, is_new=is_new)
        record.detection.add_masks(self.masks(o))

    def masks(self, o) -> List[Mask]:
        """Single-element list wrapping the mask file matched to ``o``."""
        mask_file = self._record_id2maskfile[self.record_id(o)]
        return [VocMaskFile(mask_file)]