-
Notifications
You must be signed in to change notification settings - Fork 484
/
cityscape_segmentation.py
150 lines (126 loc) · 6.56 KB
/
cityscape_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import cv2
import numpy as np
from PIL import Image, ImageColor
from super_gradients.common.object_names import Datasets
from super_gradients.common.registry.registry import register_dataset
from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
# TODO - ADD COARSE DATA - right now cityscapes dataset includes fine annotations. It's optional to use extra coarse
# annotations.
# Train-id assigned to background and to every label excluded from training/evaluation.
# target_transform() below remaps the conventional 255 ignore value to this label because
# torchmetrics' IoU cannot handle an ignore index as large as 255 (OOM crash).
CITYSCAPES_IGNORE_LABEL = 19
@register_dataset(Datasets.CITYSCAPES_DATASET)
class CityscapesDataset(SegmentationDataSet):
    """
    CityscapesDataset - Segmentation Data Set Class for Cityscapes Segmentation Data Set,
    main resolution of dataset: (2048 x 1024).
    Not all the original labels are used for training and evaluation, according to cityscape paper:
    "Classes that are too rare are excluded from our benchmark, leaving 19 classes for evaluation".
    For more details about the dataset labels format see:
    https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
    To use this Dataset you need to:
        - Download cityscape dataset (https://www.cityscapes-dataset.com/downloads/)
            root_dir (in recipe default to /data/cityscapes)
                ├─── gtFine
                │       ├── test
                │       │     ├── berlin
                │       │     │   ├── berlin_000000_000019_gtFine_color.png
                │       │     │   ├── berlin_000000_000019_gtFine_instanceIds.png
                │       │     │   └── ...
                │       │     ├── bielefeld
                │       │     │   └── ...
                │       │     └── ...
                │       ├─── train
                │       │      └── ...
                │       └─── val
                │              └── ...
                └─── leftImg8bit
                        ├── test
                        │     └── ...
                        ├─── train
                        │      └── ...
                        └─── val
                               └── ...
        - Download metadata folder (https://deci-pretrained-models.s3.amazonaws.com/cityscape_lists.zip)
            lists
                ├── labels.csv
                ├── test.lst
                ├── train.lst
                ├── trainval.lst
                └── val.lst
        - Move Metadata folder to the Cityscape folder
            root_dir (in recipe default to /data/cityscapes)
                ├─── gtFine
                │      └── ...
                ├─── leftImg8bit
                │      └── ...
                └─── lists
                       └── ...
    Example:
        >> CityscapesDataset(root_dir='.../root_dir', list_file='lists/train.lst', labels_csv_path='lists/labels.csv', ...)
    """

    def __init__(self, root_dir: str, list_file: str, labels_csv_path: str, **kwargs):
        """
        :param root_dir:        Absolute path to root directory of the dataset.
        :param list_file:       List file that contains names of images to load, line format: <image_path> <label_path>. The path is relative to root.
        :param labels_csv_path: Path to csv file, with labels metadata and mapping. The path is relative to root.
        :param kwargs:          Any hyper params required for the dataset, i.e img_size, crop_size, cache_images
        """
        self.root_dir = root_dir
        super().__init__(root_dir, list_file=list_file, **kwargs)
        # Labels metadata table. `np.recfromcsv` was deprecated and removed in NumPy 2.0;
        # this `np.genfromtxt(...).view(np.recarray)` call reproduces its exact behavior
        # (delimiter=",", names from header, lowercased field names, spaces replaced by "_"),
        # so `.field("trainid")` / `.field("ignoreineval")` access keeps working.
        self.labels_data = np.genfromtxt(
            os.path.join(self.root_dir, labels_csv_path),
            delimiter=",",
            names=True,
            dtype="<i8,U20,<i8,<i8,U12,<i8,?,?,U7",
            comments="&",
            case_sensitive="lower",
            replace_space="_",
        ).view(np.recarray)
        # Lookup vector mapping ground-truth label ids (index) to train ids (value).
        self.labels_map = self.labels_data.field("trainid")
        # Names of the classes kept for training/evaluation (ignoreInEval == False).
        self.classes = self.labels_data.field("name")[np.logical_not(self.labels_data.field("ignoreineval"))].tolist()
        # Flat [r0, g0, b0, r1, g1, b1, ...] palette for visualizing predictions.
        self.train_id_color_palette = self._create_color_palette()

    def _generate_samples_and_targets(self):
        """
        Override _generate_samples_and_targets to parse the list file and populate
        self.samples_targets_tuples_list with (image_path, label_path) absolute-path tuples.
        Line format of the list file: <image_path> <label_path> (paths relative to root).
        """
        # NOTE(review): entries are joined against self.root while __init__ stores
        # self.root_dir — assumed to be the same path, set by the parent class; confirm
        # against SegmentationDataSet before renaming either attribute.
        with open(os.path.join(self.root_dir, self.list_file_path)) as f:
            # Skip blank lines (e.g. a trailing newline) so the unpack below cannot fail.
            img_list = [line.split() for line in f if line.strip()]
        for image_path, label_path in img_list:
            self.samples_targets_tuples_list.append((os.path.join(self.root, image_path), os.path.join(self.root, label_path)))
        super()._generate_samples_and_targets()

    def target_loader(self, label_path: str) -> Image.Image:
        """
        Override target_loader to load the labels mask image and convert
        ground-truth label ids to train ids.

        :param label_path:        Path to the label image. Must be a ".png" file.
        :return:                  Grayscale ("L" mode) PIL image holding train ids.
        :raises ValueError:       If the file is not a png (lossy formats would alter label values).
        :raises FileNotFoundError: If the image cannot be read.
        """
        # `assert` is stripped under `python -O`; validate explicitly instead.
        if os.path.splitext(label_path)[-1].lower() != ".png":
            raise ValueError(f"Label file must be a .png, got: {label_path}")
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (no exception) on a missing/unreadable file.
        if label is None:
            raise FileNotFoundError(f"Failed to read label image: {label_path}")
        # Map ground-truth ids to train ids via the lookup vector.
        label = self.labels_map[label].astype(np.uint8)
        return Image.fromarray(label, "L")

    def _create_color_palette(self):
        """
        Create the color palette for visualizing segmentation masks.

        :return: Flat list of RGB components, 3 ints per evaluated class.
        """
        keep = np.logical_not(self.labels_data.field("ignoreineval"))
        hex_colors = self.labels_data.field("color")[keep].tolist()
        palette = []
        for hex_color in hex_colors:
            palette.extend(ImageColor.getcolor(hex_color, "RGB"))
        return palette

    def get_train_ids_color_palette(self):
        """Return the flat RGB palette built in __init__ (see _create_color_palette)."""
        return self.train_id_color_palette

    @staticmethod
    def target_transform(target):
        """
        target_transform - Transforms the sample image
        This function overrides the original function from SegmentationDataSet and changes target pixels with value
        255 to value = CITYSCAPES_IGNORE_LABEL. This was done since current IoU metric from torchmetrics does not
        support such a high ignore label value (crashed on OOM)

        :param target: The target mask to transform
        :return: The transformed target mask
        """
        out = SegmentationDataSet.target_transform(target)
        # Remap the conventional 255 ignore value to a small train id (19).
        out[out == 255] = CITYSCAPES_IGNORE_LABEL
        return out