From 7c5ea376b885cac19f8f08d533254422dc5697d7 Mon Sep 17 00:00:00 2001
From: Hang Zhao
Date: Sat, 23 Feb 2019 14:09:18 -0500
Subject: [PATCH 1/2] abstract dataset class, replace scipy with cv2

---
 dataset.py | 201 +++++++++++++++++++++++------------------------
 1 file changed, 86 insertions(+), 115 deletions(-)

diff --git a/dataset.py b/dataset.py
index 14592529..d771564b 100644
--- a/dataset.py
+++ b/dataset.py
@@ -4,45 +4,68 @@ import lib.utils.data as torchdata
 import cv2
 from torchvision import transforms
-from scipy.misc import imread, imresize
 import numpy as np
 
-# Round x to the nearest multiple of p and x' >= x
-def round2nearest_multiple(x, p):
-    return ((x - 1) // p + 1) * p
 
-class TrainDataset(torchdata.Dataset):
-    def __init__(self, odgt, opt, max_sample=-1, batch_per_gpu=1):
-        self.root_dataset = opt.root_dataset
+class BaseDataset(torchdata.Dataset):
+    def __init__(self, odgt, opt, **kwargs):
+        # parse options
         self.imgSize = opt.imgSize
         self.imgMaxSize = opt.imgMaxSize
-        self.random_flip = opt.random_flip
         # max down sampling rate of network to avoid rounding during conv or pooling
         self.padding_constant = opt.padding_constant
-        # down sampling rate of segm labe
-        self.segm_downsampling_rate = opt.segm_downsampling_rate
-        self.batch_per_gpu = batch_per_gpu
 
-        # classify images into two classes: 1. h > w and 2. h <= w
-        self.batch_record_list = [[], []]
-
-        # override dataset length when trainig with batch_per_gpu > 1
-        self.cur_idx = 0
+        # parse the input list
+        self.parse_input_list(odgt, **kwargs)
 
         # mean and std
-        self.img_transform = transforms.Compose([
-            transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
-            ])
+        self.normalize = transforms.Normalize(
+            mean=[102.9801, 115.9465, 122.7717],
+            std=[1., 1., 1.])
 
-        self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
+    def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
+        if isinstance(odgt, list):
+            self.list_sample = odgt
+        elif isinstance(odgt, str):
+            self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
 
-        self.if_shuffled = False
         if max_sample > 0:
             self.list_sample = self.list_sample[0:max_sample]
 
+        if start_idx >= 0 and end_idx >= 0:  # divide file list
+            self.list_sample = self.list_sample[start_idx:end_idx]
+
         self.num_sample = len(self.list_sample)
         assert self.num_sample > 0
         print('# samples: {}'.format(self.num_sample))
 
+    def img_transform(self, img):
+        # image to float
+        img = img.astype(np.float32)
+        img = img.transpose((2, 0, 1))
+        img = self.normalize(torch.from_numpy(img.copy()))
+        return img
+
+    # Round x to the nearest multiple of p and x' >= x
+    def round2nearest_multiple(self, x, p):
+        return ((x - 1) // p + 1) * p
+
+
+class TrainDataset(BaseDataset):
+    def __init__(self, odgt, opt, batch_per_gpu=1, **kwargs):
+        super(TrainDataset, self).__init__(odgt, opt, **kwargs)
+        self.root_dataset = opt.root_dataset
+        self.random_flip = opt.random_flip
+        # down sampling rate of segm label
+        self.segm_downsampling_rate = opt.segm_downsampling_rate
+        self.batch_per_gpu = batch_per_gpu
+
+        # classify images into two classes: 1. h > w and 2. h <= w
+        self.batch_record_list = [[], []]
+
+        # override dataset length when training with batch_per_gpu > 1
+        self.cur_idx = 0
+        self.if_shuffled = False
 
     def _get_sub_batch(self):
         while True:
             # get a sample record
@@ -88,22 +111,24 @@ def __getitem__(self, index):
         batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32)
         for i in range(self.batch_per_gpu):
             img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
-            this_scale = min(this_short_size / min(img_height, img_width), \
-                             self.imgMaxSize / max(img_height, img_width))
+            this_scale = min(
+                this_short_size / min(img_height, img_width), \
+                self.imgMaxSize / max(img_height, img_width))
             img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale
             batch_resized_size[i, :] = img_resized_height, img_resized_width
         batch_resized_height = np.max(batch_resized_size[:, 0])
         batch_resized_width = np.max(batch_resized_size[:, 1])
 
         # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
-        batch_resized_height = int(round2nearest_multiple(batch_resized_height, self.padding_constant))
-        batch_resized_width = int(round2nearest_multiple(batch_resized_width, self.padding_constant))
+        batch_resized_height = int(self.round2nearest_multiple(batch_resized_height, self.padding_constant))
+        batch_resized_width = int(self.round2nearest_multiple(batch_resized_width, self.padding_constant))
 
         assert self.padding_constant >= self.segm_downsampling_rate,\
             'padding constant must be equal or large than segm downsamping rate'
         batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width)
-        batch_segms = torch.zeros(self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
-                                  batch_resized_width // self.segm_downsampling_rate).long()
+        batch_segms = torch.zeros(
+            self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
+            batch_resized_width // self.segm_downsampling_rate).long()
 
         for i in range(self.batch_per_gpu):
             this_record = batch_records[i]
@@ -111,37 +136,38 @@ def __getitem__(self, index):
             # load image and label
             image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
             segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
-            img = imread(image_path, mode='RGB')
-            segm = imread(segm_path)
+            img = cv2.imread(image_path, cv2.IMREAD_COLOR)
+            segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
 
             assert(img.ndim == 3)
             assert(segm.ndim == 2)
             assert(img.shape[0] == segm.shape[0])
             assert(img.shape[1] == segm.shape[1])
 
-            if self.random_flip == True:
+            if self.random_flip is True:
                 random_flip = np.random.choice([0, 1])
                 if random_flip == 1:
                     img = cv2.flip(img, 1)
                     segm = cv2.flip(segm, 1)
 
             # note that each sample within a mini batch has different scale param
-            img = imresize(img, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='bilinear')
-            segm = imresize(segm, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='nearest')
+            img = cv2.resize(img, (batch_resized_size[i, 1], batch_resized_size[i, 0]), interpolation=cv2.INTER_LINEAR)
+            segm = cv2.resize(segm, (batch_resized_size[i, 1], batch_resized_size[i, 0]), interpolation=cv2.INTER_NEAREST)
 
             # to avoid seg label misalignment
-            segm_rounded_height = round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
-            segm_rounded_width = round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
+            segm_rounded_height = self.round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
+            segm_rounded_width = self.round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
             segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8')
             segm_rounded[:segm.shape[0], :segm.shape[1]] = segm
-            segm = imresize(segm_rounded, (segm_rounded.shape[0] // self.segm_downsampling_rate, \
-                                           segm_rounded.shape[1] // self.segm_downsampling_rate), \
-                            interp='nearest')
-            # image to float
-            img = img.astype(np.float32)[:, :, ::-1] # RGB to BGR!!!
-            img = img.transpose((2, 0, 1))
-            img = self.img_transform(torch.from_numpy(img.copy()))
+            segm = cv2.resize(
+                segm_rounded,
+                (segm_rounded.shape[1] // self.segm_downsampling_rate, \
+                 segm_rounded.shape[0] // self.segm_downsampling_rate), \
+                interpolation=cv2.INTER_NEAREST)
+
+            # image transform
+            img = self.img_transform(img)
 
             batch_images[i][:, :img.shape[1], :img.shape[2]] = img
             batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int)).long()
@@ -157,39 +183,18 @@ def __len__(self):
         #return self.num_sample
 
 
-class ValDataset(torchdata.Dataset):
-    def __init__(self, odgt, opt, max_sample=-1, start_idx=-1, end_idx=-1):
+class ValDataset(BaseDataset):
+    def __init__(self, odgt, opt, **kwargs):
+        super(ValDataset, self).__init__(odgt, opt, **kwargs)
         self.root_dataset = opt.root_dataset
-        self.imgSize = opt.imgSize
-        self.imgMaxSize = opt.imgMaxSize
-        # max down sampling rate of network to avoid rounding during conv or pooling
-        self.padding_constant = opt.padding_constant
-
-        # mean and std
-        self.img_transform = transforms.Compose([
-            transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
-            ])
-
-        self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
-
-        if max_sample > 0:
-            self.list_sample = self.list_sample[0:max_sample]
-
-        if start_idx >= 0 and end_idx >= 0:  # divide file list
-            self.list_sample = self.list_sample[start_idx:end_idx]
-
-        self.num_sample = len(self.list_sample)
-        assert self.num_sample > 0
-        print('# samples: {}'.format(self.num_sample))
 
     def __getitem__(self, index):
         this_record = self.list_sample[index]
         # load image and label
         image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
         segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
-        img = imread(image_path, mode='RGB')
-        img = img[:, :, ::-1] # BGR to RGB!!!
-        segm = imread(segm_path)
+        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
+        segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
 
         ori_height, ori_width, _ = img.shape
@@ -197,26 +202,23 @@ def __getitem__(self, index):
         for this_short_size in self.imgSize:
             # calculate target height and width
             scale = min(this_short_size / float(min(ori_height, ori_width)),
-                        self.imgMaxSize / float(max(ori_height, ori_width)))
+                self.imgMaxSize / float(max(ori_height, ori_width)))
             target_height, target_width = int(ori_height * scale), int(ori_width * scale)
 
             # to avoid rounding in network
-            target_height = round2nearest_multiple(target_height, self.padding_constant)
-            target_width = round2nearest_multiple(target_width, self.padding_constant)
+            target_height = self.round2nearest_multiple(target_height, self.padding_constant)
+            target_width = self.round2nearest_multiple(target_width, self.padding_constant)
 
             # resize
             img_resized = cv2.resize(img.copy(), (target_width, target_height))
 
-            # image to float
-            img_resized = img_resized.astype(np.float32)
-            img_resized = img_resized.transpose((2, 0, 1))
-            img_resized = self.img_transform(torch.from_numpy(img_resized))
+            # image transform
+            img_resized = self.img_transform(img_resized)
             img_resized = torch.unsqueeze(img_resized, 0)
             img_resized_list.append(img_resized)
 
         segm = torch.from_numpy(segm.astype(np.int)).long()
-
         batch_segms = torch.unsqueeze(segm, 0)
 
         batch_segms = batch_segms - 1 # label from -1 to 149
@@ -231,37 +233,15 @@ def __len__(self):
         return self.num_sample
 
 
-class TestDataset(torchdata.Dataset):
-    def __init__(self, odgt, opt, max_sample=-1):
-        self.imgSize = opt.imgSize
-        self.imgMaxSize = opt.imgMaxSize
-        # max down sampling rate of network to avoid rounding during conv or pooling
-        self.padding_constant = opt.padding_constant
-        # down sampling rate of segm labe
-        self.segm_downsampling_rate = opt.segm_downsampling_rate
-
-        # mean and std
-        self.img_transform = transforms.Compose([
-            transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
-            ])
-
-        if isinstance(odgt, list):
-            self.list_sample = odgt
-        elif isinstance(odgt, str):
-            self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
-
-        if max_sample > 0:
-            self.list_sample = self.list_sample[0:max_sample]
-        self.num_sample = len(self.list_sample)
-        assert self.num_sample > 0
-        print('# samples: {}'.format(self.num_sample))
+class TestDataset(BaseDataset):
+    def __init__(self, odgt, opt, **kwargs):
+        super(TestDataset, self).__init__(odgt, opt, **kwargs)
 
     def __getitem__(self, index):
         this_record = self.list_sample[index]
         # load image and label
         image_path = this_record['fpath_img']
-        img = imread(image_path, mode='RGB')
-        img = img[:, :, ::-1] # BGR to RGB!!!
+        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
 
         ori_height, ori_width, _ = img.shape
 
@@ -269,33 +249,24 @@ def __getitem__(self, index):
         for this_short_size in self.imgSize:
             # calculate target height and width
             scale = min(this_short_size / float(min(ori_height, ori_width)),
-                        self.imgMaxSize / float(max(ori_height, ori_width)))
+                self.imgMaxSize / float(max(ori_height, ori_width)))
             target_height, target_width = int(ori_height * scale), int(ori_width * scale)
 
             # to avoid rounding in network
-            target_height = round2nearest_multiple(target_height, self.padding_constant)
-            target_width = round2nearest_multiple(target_width, self.padding_constant)
+            target_height = self.round2nearest_multiple(target_height, self.padding_constant)
+            target_width = self.round2nearest_multiple(target_width, self.padding_constant)
 
             # resize
             img_resized = cv2.resize(img.copy(), (target_width, target_height))
 
-            # image to float
-            img_resized = img_resized.astype(np.float32)
-            img_resized = img_resized.transpose((2, 0, 1))
-            img_resized = self.img_transform(torch.from_numpy(img_resized))
-
+            # image transform
+            img_resized = self.img_transform(img_resized)
             img_resized = torch.unsqueeze(img_resized, 0)
             img_resized_list.append(img_resized)
 
-        # segm = torch.from_numpy(segm.astype(np.int)).long()
-
-        # batch_segms = torch.unsqueeze(segm, 0)
-
-        # batch_segms = batch_segms - 1 # label from -1 to 149
 
         output = dict()
         output['img_ori'] = img.copy()
         output['img_data'] = [x.contiguous() for x in img_resized_list]
-        # output['seg_label'] = batch_segms.contiguous()
         output['info'] = this_record['fpath_img']
         return output
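Note on PATCH 1/2 (commentary, not part of the patch): two conventions this refactor relies on are easy to get backwards. `round2nearest_multiple` rounds up to the nearest multiple of `p`, and `cv2.resize` takes its `dsize` argument as `(width, height)` — the reverse of the `(height, width)` order used by the removed `scipy.misc.imresize` — which is why the resize calls above swap the size indices, e.g. `batch_resized_size[i, 1], batch_resized_size[i, 0]`. A minimal standalone sketch (the example sizes are made up):

```python
import cv2
import numpy as np

def round2nearest_multiple(x, p):
    # Round x up to the nearest multiple of p, so x' >= x and p divides x'
    return ((x - 1) // p + 1) * p

assert round2nearest_multiple(333, 8) == 336  # 42 * 8, the next multiple above 333
assert round2nearest_multiple(336, 8) == 336  # already a multiple: unchanged

img = np.zeros((240, 320, 3), dtype=np.uint8)  # array shape is (H, W, C)
out = cv2.resize(img, (160, 120), interpolation=cv2.INTER_LINEAR)  # dsize is (W, H)
assert out.shape[:2] == (120, 160)
```

Dropping the explicit RGB/BGR flips is consistent with the same switch: `cv2.imread` already returns BGR, which is the channel order the normalization mean `[102.9801, 115.9465, 122.7717]` expects.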
From ce93582c8a8d8329e35a3a84d95b2508bba62efb Mon Sep 17 00:00:00 2001
From: Hang Zhao
Date: Tue, 23 Apr 2019 17:13:21 -0400
Subject: [PATCH 2/2] enable testing on a folder of images

---
 README.md |  2 +-
 test.py   | 16 +++++++++-------
 utils.py  | 12 +++++++++++-
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index af1e5310..587c2437 100644
--- a/README.md
+++ b/README.md
@@ -164,7 +164,7 @@ chmod +x demo_test.sh
 ```
 This script downloads a trained model (ResNet50dilated + PPM_deepsup) and a test image, runs the test script, and saves predicted segmentation (.png) to the working directory.
 
-2. To test on multiple images, you can simply do something as the following (```$PATH_IMG1, $PATH_IMG2, $PATH_IMG3```are your image paths):
+2. To test on multiple images or a folder of images, you can simply do something like the following (```$PATH_IMG1, $PATH_IMG2, $PATH_IMG3``` are your image paths):
 ```
 python3 -u test.py \
   --model_path $MODEL_PATH \
diff --git a/test.py b/test.py
index 6ddf8f14..aec9b62e 100644
--- a/test.py
+++ b/test.py
@@ -10,7 +10,7 @@
 # Our libs
 from dataset import TestDataset
 from models import ModelBuilder, SegmentationModule
-from utils import colorEncode
+from utils import colorEncode, find_recursive
 from lib.nn import user_scattered_collate, async_copy_to
 from lib.utils import as_numpy
 import lib.utils.data as torchdata
@@ -24,11 +24,10 @@ def visualize_result(data, pred, args):
     (img, info) = data
 
     # prediction
-    pred_color = colorEncode(pred, colors)
+    pred_color = colorEncode(pred, colors).astype(np.uint8)
 
     # aggregate images and save
-    im_vis = np.concatenate((img, pred_color),
-                            axis=1).astype(np.uint8)
+    im_vis = np.concatenate((img, pred_color), axis=1)
 
     img_name = info.split('/')[-1]
     cv2.imwrite(os.path.join(args.result,
@@ -93,8 +92,11 @@ def main(args):
     segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
 
     # Dataset and Loader
-    # list_test = [{'fpath_img': args.test_img}]
-    list_test = [{'fpath_img': x} for x in args.test_imgs]
+    if len(args.test_imgs) == 1 and os.path.isdir(args.test_imgs[0]):
+        test_imgs = find_recursive(args.test_imgs[0])
+    else:
+        test_imgs = args.test_imgs
+    list_test = [{'fpath_img': x} for x in test_imgs]
     dataset_test = TestDataset(
         list_test, args, max_sample=args.num_val)
     loader_test = torchdata.DataLoader(
@@ -120,7 +122,7 @@
     parser = argparse.ArgumentParser()
     # Path related arguments
     parser.add_argument('--test_imgs', required=True, nargs='+', type=str,
-                        help='a list of image paths that needs to be tested')
+                        help='a list of image paths, or a directory name')
     parser.add_argument('--model_path', required=True,
                         help='folder to model path')
     parser.add_argument('--suffix', default='_epoch_20.pth',
diff --git a/utils.py b/utils.py
index 96e20aec..09bfc64e 100644
--- a/utils.py
+++ b/utils.py
@@ -1,6 +1,16 @@
-import numpy as np
+import os
 import re
 import functools
+import fnmatch
+import numpy as np
+
+
+def find_recursive(root_dir, ext='.jpg'):
+    files = []
+    for root, dirnames, filenames in os.walk(root_dir):
+        for filename in fnmatch.filter(filenames, '*' + ext):
+            files.append(os.path.join(root, filename))
+    return files
 
 
 class AverageMeter(object):
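Note on PATCH 2/2 (commentary, not part of the patch): with this change, passing a single directory to `--test_imgs` makes `test.py` expand it recursively before building the dataset. A sketch of the resulting behavior (the folder name is illustrative; `find_recursive` matches only `*.jpg` unless a different `ext` is passed):

```python
import os
from utils import find_recursive

test_imgs = ['./my_images']  # what args.test_imgs looks like for a folder run
if len(test_imgs) == 1 and os.path.isdir(test_imgs[0]):
    # a lone directory argument is expanded to every '*.jpg' beneath it
    test_imgs = find_recursive(test_imgs[0])
list_test = [{'fpath_img': x} for x in test_imgs]
```

So `python3 -u test.py --model_path $MODEL_PATH --test_imgs $PATH_TO_FOLDER ...` now works alongside the explicit list form shown in the README.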