@@ -184,6 +184,9 @@ def __init__(self, odgt, opt, max_sample=-1, start_idx=-1, end_idx=-1):
184184 # max down sampling rate of network to avoid rounding during conv or pooling
185185 self .padding_constant = opt .padding_constant
186186
187+ # number of layers used to make predictions
188+ self .nr_layers = 4
189+
187190 # mean and std
188191 self .img_transform = transforms .Compose ([
189192 transforms .Normalize (mean = [102.9801 , 115.9465 , 122.7717 ], std = [1. , 1. , 1. ])
@@ -207,12 +210,13 @@ def __getitem__(self, index):
207210 image_path = os .path .join (self .root_dataset , this_record ['fpath_img' ])
208211 segm_path = os .path .join (self .root_dataset , this_record ['fpath_segm' ])
209212 img = imread (image_path , mode = 'RGB' )
210- img = img [:, :, ::- 1 ] # BGR to RGB !!!
211- segm = imread (segm_path )
213+ img = img [:, :, ::- 1 ] # RGB to BGR !!!
214+ segm_ori = imread (segm_path )
212215
213216 ori_height , ori_width , _ = img .shape
214217
215218 img_resized_list = []
219+ segm_gt_list = []
216220 for this_short_size in self .imgSize :
217221 # calculate target height and width
218222 scale = min (this_short_size / float (min (ori_height , ori_width )),
@@ -234,15 +238,31 @@ def __getitem__(self, index):
234238 img_resized = torch .unsqueeze (img_resized , 0 )
235239 img_resized_list .append (img_resized )
236240
237- segm = torch .from_numpy (segm .astype (np .int )).long ()
241+ # construct ground-truth label map for each layer
242+ standard_segm_h , standard_segm_w = segm_ori .shape [0 ], segm_ori .shape [1 ]
243+ segm = segm_ori .copy ()
244+ for id_layer in reversed (range (self .nr_layers )):
245+ # downsampling first
246+ this_segm = imresize (segm , (target_height // (2 ** (2 + id_layer )), target_width // (2 ** (2 + id_layer ))),
247+ interp = 'nearest' )
248+ # upsampling the downsampled segm
249+ this_segm_upsampled = imresize (this_segm , (standard_segm_h , standard_segm_w ), interp = 'nearest' )
250+ # pixels whose label survives the down/up-sampling round trip are predicted at this layer
251+ this_segm_gt = this_segm_upsampled * (segm == this_segm_upsampled )
252+ segm_gt_list .append (torch .from_numpy (this_segm_gt .astype (np .int )).long ()- 1 )
253+ # remove already assigned labels (keep unassigned labels)
254+ segm = segm * (this_segm_gt == 0 )
255+
256+ segm_ori = torch .from_numpy (segm_ori .astype (np .int )).long ()
238257
239- batch_segms = torch .unsqueeze (segm , 0 )
258+ batch_segms = torch .unsqueeze (segm_ori , 0 )
240259
241260 batch_segms = batch_segms - 1 # label from -1 to 149
242261 output = dict ()
243262 output ['img_ori' ] = img .copy ()
244263 output ['img_data' ] = [x .contiguous () for x in img_resized_list ]
245264 output ['seg_label' ] = batch_segms .contiguous ()
265+ output ['seg_gt_list' ] = segm_gt_list
246266 output ['info' ] = this_record ['fpath_img' ]
247267 return output
248268
0 commit comments