2 changes: 1 addition & 1 deletion README.md
@@ -164,7 +164,7 @@ chmod +x demo_test.sh
```
This script downloads a trained model (ResNet50dilated + PPM_deepsup) and a test image, runs the test script, and saves the predicted segmentation (.png) to the working directory.

2. To test on multiple images, you can simply do the following (```$PATH_IMG1, $PATH_IMG2, $PATH_IMG3``` are your image paths):
2. To test on multiple images or a folder of images, you can simply do the following (```$PATH_IMG1, $PATH_IMG2, $PATH_IMG3``` are your image paths):
```
python3 -u test.py \
--model_path $MODEL_PATH \
    ...
```
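For reference, a complete invocation consistent with the flags visible in this PR's test.py (`$MODEL_PATH` and the image paths are placeholders; `--suffix` shows its default from the argument parser):
```
python3 -u test.py \
    --model_path $MODEL_PATH \
    --test_imgs $PATH_IMG1 $PATH_IMG2 $PATH_IMG3 \
    --suffix _epoch_20.pth
```
With this change, `--test_imgs` may also be a single directory name, which `find_recursive` expands into individual image paths.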
201 changes: 86 additions & 115 deletions dataset.py
@@ -4,45 +4,68 @@
import lib.utils.data as torchdata
import cv2
from torchvision import transforms
from scipy.misc import imread, imresize
import numpy as np

# Round x up to the nearest multiple of p (result >= x)
def round2nearest_multiple(x, p):
return ((x - 1) // p + 1) * p

class TrainDataset(torchdata.Dataset):
def __init__(self, odgt, opt, max_sample=-1, batch_per_gpu=1):
self.root_dataset = opt.root_dataset
class BaseDataset(torchdata.Dataset):
def __init__(self, odgt, opt, **kwargs):
# parse options
self.imgSize = opt.imgSize
self.imgMaxSize = opt.imgMaxSize
self.random_flip = opt.random_flip
# max downsampling rate of the network, to avoid rounding during conv or pooling
self.padding_constant = opt.padding_constant
# downsampling rate of segm label
self.segm_downsampling_rate = opt.segm_downsampling_rate
self.batch_per_gpu = batch_per_gpu

# classify images into two classes: 1. h > w and 2. h <= w
self.batch_record_list = [[], []]

# override dataset length when training with batch_per_gpu > 1
self.cur_idx = 0
# parse the input list
self.parse_input_list(odgt, **kwargs)

# mean and std
self.img_transform = transforms.Compose([
transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
])
self.normalize = transforms.Normalize(
mean=[102.9801, 115.9465, 122.7717],
std=[1., 1., 1.])

self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
if isinstance(odgt, list):
self.list_sample = odgt
elif isinstance(odgt, str):
self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]

self.if_shuffled = False
if max_sample > 0:
self.list_sample = self.list_sample[0:max_sample]
if start_idx >= 0 and end_idx >= 0: # divide file list
self.list_sample = self.list_sample[start_idx:end_idx]

self.num_sample = len(self.list_sample)
assert self.num_sample > 0
print('# samples: {}'.format(self.num_sample))
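For reference, each line of an .odgt file is one JSON record; the fields read elsewhere in this diff are fpath_img, fpath_segm, height, and width (the paths and sizes below are illustrative):
```
{"fpath_img": "ADEChallengeData2016/images/training/ADE_train_00000001.jpg", "fpath_segm": "ADEChallengeData2016/annotations/training/ADE_train_00000001.png", "width": 683, "height": 512}
```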

def img_transform(self, img):
# image to float
img = img.astype(np.float32)
img = img.transpose((2, 0, 1))
img = self.normalize(torch.from_numpy(img.copy()))
return img

# Round x up to the nearest multiple of p (result >= x)
def round2nearest_multiple(self, x, p):
return ((x - 1) // p + 1) * p
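For intuition, the helper computes the smallest multiple of p that is >= x via integer arithmetic; a quick standalone check (function body copied from the diff):
```
def round2nearest_multiple(x, p):
    # smallest multiple of p that is >= x
    return ((x - 1) // p + 1) * p

assert round2nearest_multiple(33, 8) == 40
assert round2nearest_multiple(32, 8) == 32
```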


class TrainDataset(BaseDataset):
def __init__(self, odgt, opt, batch_per_gpu=1, **kwargs):
super(TrainDataset, self).__init__(odgt, opt, **kwargs)
self.root_dataset = opt.root_dataset
self.random_flip = opt.random_flip
# downsampling rate of segm label
self.segm_downsampling_rate = opt.segm_downsampling_rate
self.batch_per_gpu = batch_per_gpu

# classify images into two classes: 1. h > w and 2. h <= w
self.batch_record_list = [[], []]

# override dataset length when training with batch_per_gpu > 1
self.cur_idx = 0
self.if_shuffled = False

def _get_sub_batch(self):
while True:
# get a sample record
@@ -88,60 +111,63 @@ def __getitem__(self, index):
batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32)
for i in range(self.batch_per_gpu):
img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
this_scale = min(this_short_size / min(img_height, img_width), \
self.imgMaxSize / max(img_height, img_width))
this_scale = min(
this_short_size / min(img_height, img_width), \
self.imgMaxSize / max(img_height, img_width))
img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale
batch_resized_size[i, :] = img_resized_height, img_resized_width
batch_resized_height = np.max(batch_resized_size[:, 0])
batch_resized_width = np.max(batch_resized_size[:, 1])

# Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
batch_resized_height = int(round2nearest_multiple(batch_resized_height, self.padding_constant))
batch_resized_width = int(round2nearest_multiple(batch_resized_width, self.padding_constant))
batch_resized_height = int(self.round2nearest_multiple(batch_resized_height, self.padding_constant))
batch_resized_width = int(self.round2nearest_multiple(batch_resized_width, self.padding_constant))

assert self.padding_constant >= self.segm_downsampling_rate,\
'padding constant must be equal to or larger than segm downsampling rate'
batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width)
batch_segms = torch.zeros(self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
batch_resized_width // self.segm_downsampling_rate).long()
batch_segms = torch.zeros(
self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
batch_resized_width // self.segm_downsampling_rate).long()

for i in range(self.batch_per_gpu):
this_record = batch_records[i]

# load image and label
image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
img = imread(image_path, mode='RGB')
segm = imread(segm_path)
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)

assert(img.ndim == 3)
assert(segm.ndim == 2)
assert(img.shape[0] == segm.shape[0])
assert(img.shape[1] == segm.shape[1])

if self.random_flip == True:
if self.random_flip:
random_flip = np.random.choice([0, 1])
if random_flip == 1:
img = cv2.flip(img, 1)
segm = cv2.flip(segm, 1)

# note that each sample within a mini-batch has a different scale param
img = imresize(img, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='bilinear')
segm = imresize(segm, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='nearest')
img = cv2.resize(img, (batch_resized_size[i, 1], batch_resized_size[i, 0]), interpolation=cv2.INTER_LINEAR)
segm = cv2.resize(segm, (batch_resized_size[i, 1], batch_resized_size[i, 0]), interpolation=cv2.INTER_NEAREST)

# to avoid seg label misalignment
segm_rounded_height = round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
segm_rounded_width = round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
segm_rounded_height = self.round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
segm_rounded_width = self.round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8')
segm_rounded[:segm.shape[0], :segm.shape[1]] = segm

segm = imresize(segm_rounded, (segm_rounded.shape[0] // self.segm_downsampling_rate, \
segm_rounded.shape[1] // self.segm_downsampling_rate), \
interp='nearest')
# image to float
img = img.astype(np.float32)[:, :, ::-1] # RGB to BGR!!!
img = img.transpose((2, 0, 1))
img = self.img_transform(torch.from_numpy(img.copy()))
segm = cv2.resize(
segm_rounded,
(segm_rounded.shape[1] // self.segm_downsampling_rate, \
segm_rounded.shape[0] // self.segm_downsampling_rate), \
interpolation=cv2.INTER_NEAREST)

# image transform
img = self.img_transform(img)

batch_images[i][:, :img.shape[1], :img.shape[2]] = img
batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int)).long()
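Note on the color handling in this hunk: `cv2.imread` returns BGR, and the normalization means ([102.9801, 115.9465, 122.7717]) are Caffe-style BGR channel means, which is why the old explicit RGB-to-BGR flip is dropped. A standalone sketch of the equivalent transform (it mirrors the `img_transform` method added to BaseDataset above; assumes a uint8 BGR input):
```
import numpy as np
import torch
from torchvision import transforms

normalize = transforms.Normalize(mean=[102.9801, 115.9465, 122.7717],
                                 std=[1., 1., 1.])

def img_transform(img_bgr):
    # HWC uint8 BGR -> CHW float tensor, per-channel mean subtracted
    x = img_bgr.astype(np.float32).transpose((2, 0, 1))
    return normalize(torch.from_numpy(x.copy()))
```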
@@ -157,66 +183,42 @@ def __len__(self):
# return self.num_sample


class ValDataset(torchdata.Dataset):
def __init__(self, odgt, opt, max_sample=-1, start_idx=-1, end_idx=-1):
class ValDataset(BaseDataset):
def __init__(self, odgt, opt, **kwargs):
super(ValDataset, self).__init__(odgt, opt, **kwargs)
self.root_dataset = opt.root_dataset
self.imgSize = opt.imgSize
self.imgMaxSize = opt.imgMaxSize
# max down sampling rate of network to avoid rounding during conv or pooling
self.padding_constant = opt.padding_constant

# mean and std
self.img_transform = transforms.Compose([
transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
])

self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]

if max_sample > 0:
self.list_sample = self.list_sample[0:max_sample]

if start_idx >= 0 and end_idx >= 0: # divide file list
self.list_sample = self.list_sample[start_idx:end_idx]

self.num_sample = len(self.list_sample)
assert self.num_sample > 0
print('# samples: {}'.format(self.num_sample))

def __getitem__(self, index):
this_record = self.list_sample[index]
# load image and label
image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
img = imread(image_path, mode='RGB')
img = img[:, :, ::-1] # BGR to RGB!!!
segm = imread(segm_path)
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)

ori_height, ori_width, _ = img.shape

img_resized_list = []
for this_short_size in self.imgSize:
# calculate target height and width
scale = min(this_short_size / float(min(ori_height, ori_width)),
self.imgMaxSize / float(max(ori_height, ori_width)))
self.imgMaxSize / float(max(ori_height, ori_width)))
target_height, target_width = int(ori_height * scale), int(ori_width * scale)

# to avoid rounding in network
target_height = round2nearest_multiple(target_height, self.padding_constant)
target_width = round2nearest_multiple(target_width, self.padding_constant)
target_height = self.round2nearest_multiple(target_height, self.padding_constant)
target_width = self.round2nearest_multiple(target_width, self.padding_constant)

# resize
img_resized = cv2.resize(img.copy(), (target_width, target_height))

# image to float
img_resized = img_resized.astype(np.float32)
img_resized = img_resized.transpose((2, 0, 1))
img_resized = self.img_transform(torch.from_numpy(img_resized))
# image transform
img_resized = self.img_transform(img_resized)

img_resized = torch.unsqueeze(img_resized, 0)
img_resized_list.append(img_resized)

segm = torch.from_numpy(segm.astype(np.int)).long()

batch_segms = torch.unsqueeze(segm, 0)

batch_segms = batch_segms - 1 # label from -1 to 149
@@ -231,71 +233,40 @@ def __len__(self):
return self.num_sample


class TestDataset(torchdata.Dataset):
def __init__(self, odgt, opt, max_sample=-1):
self.imgSize = opt.imgSize
self.imgMaxSize = opt.imgMaxSize
# max down sampling rate of network to avoid rounding during conv or pooling
self.padding_constant = opt.padding_constant
# downsampling rate of segm label
self.segm_downsampling_rate = opt.segm_downsampling_rate

# mean and std
self.img_transform = transforms.Compose([
transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
])

if isinstance(odgt, list):
self.list_sample = odgt
elif isinstance(odgt, str):
self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]

if max_sample > 0:
self.list_sample = self.list_sample[0:max_sample]
self.num_sample = len(self.list_sample)
assert self.num_sample > 0
print('# samples: {}'.format(self.num_sample))
class TestDataset(BaseDataset):
def __init__(self, odgt, opt, **kwargs):
super(TestDataset, self).__init__(odgt, opt, **kwargs)

def __getitem__(self, index):
this_record = self.list_sample[index]
# load image and label
image_path = this_record['fpath_img']
img = imread(image_path, mode='RGB')
img = img[:, :, ::-1] # BGR to RGB!!!
img = cv2.imread(image_path, cv2.IMREAD_COLOR)

ori_height, ori_width, _ = img.shape

img_resized_list = []
for this_short_size in self.imgSize:
# calculate target height and width
scale = min(this_short_size / float(min(ori_height, ori_width)),
self.imgMaxSize / float(max(ori_height, ori_width)))
self.imgMaxSize / float(max(ori_height, ori_width)))
target_height, target_width = int(ori_height * scale), int(ori_width * scale)

# to avoid rounding in network
target_height = round2nearest_multiple(target_height, self.padding_constant)
target_width = round2nearest_multiple(target_width, self.padding_constant)
target_height = self.round2nearest_multiple(target_height, self.padding_constant)
target_width = self.round2nearest_multiple(target_width, self.padding_constant)

# resize
img_resized = cv2.resize(img.copy(), (target_width, target_height))

# image to float
img_resized = img_resized.astype(np.float32)
img_resized = img_resized.transpose((2, 0, 1))
img_resized = self.img_transform(torch.from_numpy(img_resized))

# image transform
img_resized = self.img_transform(img_resized)
img_resized = torch.unsqueeze(img_resized, 0)
img_resized_list.append(img_resized)

# segm = torch.from_numpy(segm.astype(np.int)).long()

# batch_segms = torch.unsqueeze(segm, 0)

# batch_segms = batch_segms - 1 # label from -1 to 149
output = dict()
output['img_ori'] = img.copy()
output['img_data'] = [x.contiguous() for x in img_resized_list]
# output['seg_label'] = batch_segms.contiguous()
output['info'] = this_record['fpath_img']
return output

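Taken together, the refactor moves list parsing into BaseDataset, so every dataset now accepts either an .odgt file path or an in-memory list of records, plus max_sample and start_idx/end_idx slicing. A minimal usage sketch (`opt` stands in for the parsed options; its field values here are illustrative, and the image path is a placeholder):
```
from types import SimpleNamespace
from dataset import TestDataset

# stand-in for the parsed argparse options read by BaseDataset
opt = SimpleNamespace(imgSize=[300, 400, 500, 600], imgMaxSize=1000,
                      padding_constant=8)

list_test = [{'fpath_img': 'path/to/image.jpg'}]  # odgt may now be a list
dataset = TestDataset(list_test, opt, max_sample=-1)
sample = dataset[0]  # dict with 'img_ori', 'img_data', 'info'
```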
16 changes: 9 additions & 7 deletions test.py
@@ -10,7 +10,7 @@
# Our libs
from dataset import TestDataset
from models import ModelBuilder, SegmentationModule
from utils import colorEncode
from utils import colorEncode, find_recursive
from lib.nn import user_scattered_collate, async_copy_to
from lib.utils import as_numpy
import lib.utils.data as torchdata
@@ -24,11 +24,10 @@ def visualize_result(data, pred, args):
(img, info) = data

# prediction
pred_color = colorEncode(pred, colors)
pred_color = colorEncode(pred, colors).astype(np.uint8)

# aggregate images and save
im_vis = np.concatenate((img, pred_color),
axis=1).astype(np.uint8)
im_vis = np.concatenate((img, pred_color), axis=1)

img_name = info.split('/')[-1]
cv2.imwrite(os.path.join(args.result,
@@ -93,8 +92,11 @@ def main(args):
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

# Dataset and Loader
# list_test = [{'fpath_img': args.test_img}]
list_test = [{'fpath_img': x} for x in args.test_imgs]
if len(args.test_imgs) == 1 and os.path.isdir(args.test_imgs[0]):
test_imgs = find_recursive(args.test_imgs[0])
else:
test_imgs = args.test_imgs
list_test = [{'fpath_img': x} for x in test_imgs]
dataset_test = TestDataset(
list_test, args, max_sample=args.num_val)
loader_test = torchdata.DataLoader(
@@ -120,7 +122,7 @@
parser = argparse.ArgumentParser()
# Path related arguments
parser.add_argument('--test_imgs', required=True, nargs='+', type=str,
help='a list of image paths that needs to be tested')
help='a list of image paths, or a directory name')
parser.add_argument('--model_path', required=True,
help='folder to model path')
parser.add_argument('--suffix', default='_epoch_20.pth',
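`find_recursive` is imported from `utils` but its body is outside this diff; a plausible sketch of what it does (an assumption for illustration, not the repo's actual implementation):
```
import os

def find_recursive(root_dir, ext='.jpg'):
    # hypothetical helper: walk root_dir and collect files ending in ext
    files = []
    for dirpath, _, fnames in os.walk(root_dir):
        for fname in fnames:
            if fname.lower().endswith(ext):
                files.append(os.path.join(dirpath, fname))
    return files
```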