-
Notifications
You must be signed in to change notification settings - Fork 20
/
dataset.py
89 lines (80 loc) · 2.51 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import cv2
import numpy as np
from torchvision.datasets import VOCSegmentation
# Disable OpenCV's internal threading and OpenCL acceleration: when this
# dataset is used inside PyTorch DataLoader worker processes, OpenCV's own
# parallelism can oversubscribe CPUs or deadlock after fork.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
# The 21 Pascal VOC semantic segmentation class names; index 0 is background.
VOC_CLASSES = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "potted plant",
    "sheep",
    "sofa",
    "train",
    "tv/monitor",
]
# RGB color used to encode each class in the VOC color-coded annotation
# images; VOC_COLORMAP[i] is the color for VOC_CLASSES[i].
VOC_COLORMAP = [
    [0, 0, 0],
    [128, 0, 0],
    [0, 128, 0],
    [128, 128, 0],
    [0, 0, 128],
    [128, 0, 128],
    [0, 128, 128],
    [128, 128, 128],
    [64, 0, 0],
    [192, 0, 0],
    [64, 128, 0],
    [192, 128, 0],
    [64, 0, 128],
    [192, 0, 128],
    [64, 128, 128],
    [192, 128, 128],
    [0, 64, 0],
    [128, 64, 0],
    [0, 192, 0],
    [128, 192, 0],
    [0, 64, 128],
]
class PascalVOCSearchDataset(VOCSegmentation):
    """Pascal VOC segmentation dataset yielding (image, one-hot mask) pairs.

    Images are loaded with OpenCV as RGB arrays, and the color-coded VOC
    annotation is expanded into a ``[height, width, num_classes]`` float32
    mask — the layout AutoAlbument expects. An optional albumentations-style
    ``transform`` (called as ``transform(image=..., mask=...)``) is applied
    to both.
    """

    def __init__(self, root="~/data/pascal_voc", image_set="train", download=True, transform=None):
        super().__init__(root=root, image_set=image_set, download=download, transform=transform)

    @staticmethod
    def _convert_to_segmentation_mask(mask):
        """Turn an RGB-encoded VOC annotation into a one-hot float32 mask.

        Pascal VOC encodes the class of each pixel as an RGB color.
        AutoAlbument instead wants one channel per class, holding 1.0 where
        the pixel belongs to that class and 0.0 everywhere else.
        """
        one_hot_channels = [np.all(mask == color, axis=-1) for color in VOC_COLORMAP]
        return np.stack(one_hot_channels, axis=-1).astype(np.float32)

    def __getitem__(self, index):
        # OpenCV decodes to BGR; convert both image and annotation to RGB
        # before expanding the annotation colors into per-class channels.
        rgb_image = cv2.cvtColor(cv2.imread(self.images[index]), cv2.COLOR_BGR2RGB)
        rgb_annotation = cv2.cvtColor(cv2.imread(self.masks[index]), cv2.COLOR_BGR2RGB)
        one_hot_mask = self._convert_to_segmentation_mask(rgb_annotation)
        if self.transform is None:
            return rgb_image, one_hot_mask
        augmented = self.transform(image=rgb_image, mask=one_hot_mask)
        return augmented["image"], augmented["mask"]