Skip to content

Commit 3ad95d1

Browse files
author
Tete Xiao
committed
add eval for exp
1 parent 7fc7b75 commit 3ad95d1

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

dataset.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ def __init__(self, odgt, opt, max_sample=-1, start_idx=-1, end_idx=-1):
184184
# max down sampling rate of network to avoid rounding during conv or pooling
185185
self.padding_constant = opt.padding_constant
186186

187+
# how many layers used to do predictions
188+
self.nr_layers = 4
189+
187190
# mean and std
188191
self.img_transform = transforms.Compose([
189192
transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
@@ -207,12 +210,13 @@ def __getitem__(self, index):
207210
image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
208211
segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
209212
img = imread(image_path, mode='RGB')
210-
img = img[:, :, ::-1] # BGR to RGB!!!
211-
segm = imread(segm_path)
213+
img = img[:, :, ::-1] # RGB to BGR!!!
214+
segm_ori = imread(segm_path)
212215

213216
ori_height, ori_width, _ = img.shape
214217

215218
img_resized_list = []
219+
segm_gt_list = []
216220
for this_short_size in self.imgSize:
217221
# calculate target height and width
218222
scale = min(this_short_size / float(min(ori_height, ori_width)),
@@ -234,15 +238,31 @@ def __getitem__(self, index):
234238
img_resized = torch.unsqueeze(img_resized, 0)
235239
img_resized_list.append(img_resized)
236240

237-
segm = torch.from_numpy(segm.astype(np.int)).long()
241+
# construct ground-truth label map for each layer
242+
standard_segm_h, standard_segm_w = segm_ori.shape[0], segm_ori.shape[1]
243+
segm = segm_ori.copy()
244+
for id_layer in reversed(range(self.nr_layers)):
245+
# downsampling first
246+
this_segm = imresize(segm, (target_height // (2 ** (2+id_layer)), target_width // (2 ** (2+id_layer))),
247+
interp='nearest')
248+
# upsampling the downsampled segm
249+
this_segm_upsampled = imresize(this_segm, (standard_segm_h, standard_segm_w), interp='nearest')
250+
# for those labels that are still correct, we predict them at this layer
251+
this_segm_gt = this_segm_upsampled * (segm == this_segm_upsampled)
252+
segm_gt_list.append(torch.from_numpy(this_segm_gt.astype(np.int)).long()-1)
253+
# remove already assigned labels (keep unassigned labels)
254+
segm = segm * (this_segm_gt == 0)
255+
256+
segm_ori = torch.from_numpy(segm_ori.astype(np.int)).long()
238257

239-
batch_segms = torch.unsqueeze(segm, 0)
258+
batch_segms = torch.unsqueeze(segm_ori, 0)
240259

241260
batch_segms = batch_segms - 1 # label from -1 to 149
242261
output = dict()
243262
output['img_ori'] = img.copy()
244263
output['img_data'] = [x.contiguous() for x in img_resized_list]
245264
output['seg_label'] = batch_segms.contiguous()
265+
output['seg_gt_list'] = segm_gt_list
246266
output['info'] = this_record['fpath_img']
247267
return output
248268

eval_multipro.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def visualize_result(data, preds, args):
3939
cv2.imwrite(os.path.join(args.result,
4040
img_name.replace('.jpg', '.png')), im_vis)
4141

42+
def get_pred_map_via_gt(seg_gt_list, pred_list):
    """Merge per-layer predictions into a single score map, gated by the
    per-layer ground-truth assignment.

    Each layer's prediction only contributes at pixels that layer is
    responsible for, i.e. where its ground-truth label map is >= 0
    (unassigned pixels carry -1 after the dataset's `-1` shift).

    Args:
        seg_gt_list: per-layer ground-truth label tensors of shape (H, W);
            a value >= 0 marks a pixel assigned to that layer.
        pred_list: per-layer prediction tensors of shape (C, H, W),
            aligned index-for-index with `seg_gt_list`.

    Returns:
        A tensor shaped like `pred_list[0]` holding, at every pixel, the
        prediction of the layer that owns it (sum of the masked layers).
    """
    # zeros_like already matches dtype and device of pred_list[0];
    # the original extra .type_as() call was a no-op.
    pred_map = torch.zeros_like(pred_list[0])
    for this_seg_gt, this_pred in zip(seg_gt_list, pred_list):
        # (H, W) mask -> (1, H, W) so it broadcasts over the channel dim.
        pred_map += (this_seg_gt >= 0).long().unsqueeze(0) * this_pred
    return pred_map
4247

4348
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
4449

@@ -50,6 +55,7 @@ def evaluate(segmentation_module, loader, args, dev_id, result_queue):
5055
seg_label = as_numpy(batch_data['seg_label'][0])
5156

5257
img_resized_list = batch_data['img_data']
58+
seg_gt_list = batch_data['seg_gt_list']
5359

5460
with torch.no_grad():
5561
segSize = (seg_label.shape[0], seg_label.shape[1])
@@ -60,11 +66,12 @@ def evaluate(segmentation_module, loader, args, dev_id, result_queue):
6066
feed_dict['img_data'] = img
6167
del feed_dict['img_ori']
6268
del feed_dict['info']
69+
del feed_dict['seg_gt_list']
6370
feed_dict = async_copy_to(feed_dict, dev_id)
6471

6572
# forward pass
66-
pred_tmp = segmentation_module(feed_dict, segSize=segSize)
67-
pred = pred + pred_tmp.cpu() / len(args.imgSize)
73+
pred_list = segmentation_module(feed_dict, segSize=segSize)
74+
pred = pred + get_pred_map_via_gt(seg_gt_list, pred_list).cpu() / len(args.imgSize)
6875

6976
_, preds = torch.max(pred.data.cpu(), dim=1)
7077
preds = as_numpy(preds.squeeze(0))

0 commit comments

Comments
 (0)