diff --git a/eval.py b/eval.py index bfeed134..bba76173 100644 --- a/eval.py +++ b/eval.py @@ -1,6 +1,6 @@ # System libs import os -import datetime +import time import argparse from distutils.version import LooseVersion # Numerical libs @@ -16,18 +16,19 @@ from lib.utils import as_numpy, mark_volatile import lib.utils.data as torchdata import cv2 +from tqdm import tqdm colors = loadmat('data/color150.mat')['colors'] -def visualize_result(data, preds, args): +def visualize_result(data, pred, args): (img, seg, info) = data # segmentation seg_color = colorEncode(seg, colors) # prediction - pred_color = colorEncode(preds, colors) + pred_color = colorEncode(pred, colors) # aggregate images and save im_vis = np.concatenate((img, seg_color, pred_color), @@ -42,19 +43,22 @@ def evaluate(segmentation_module, loader, args): acc_meter = AverageMeter() intersection_meter = AverageMeter() union_meter = AverageMeter() + time_meter = AverageMeter() segmentation_module.eval() - for i, batch_data in enumerate(loader): + pbar = tqdm(total=len(loader)) + for batch_data in loader: # process data batch_data = batch_data[0] seg_label = as_numpy(batch_data['seg_label'][0]) - img_resized_list = batch_data['img_data'] + tic = time.time() with torch.no_grad(): segSize = (seg_label.shape[0], seg_label.shape[1]) - pred = torch.zeros(1, args.num_class, segSize[0], segSize[1]) + scores = torch.zeros(1, args.num_class, segSize[0], segSize[1]) + scores = async_copy_to(scores, args.gpu_id) for img in img_resized_list: feed_dict = batch_data.copy() @@ -64,35 +68,37 @@ def evaluate(segmentation_module, loader, args): feed_dict = async_copy_to(feed_dict, args.gpu_id) # forward pass - pred_tmp = segmentation_module(feed_dict, segSize=segSize) - pred = pred + pred_tmp.cpu() / len(args.imgSize) + scores_tmp = segmentation_module(feed_dict, segSize=segSize) + scores = scores + scores_tmp / len(args.imgSize) + + _, pred = torch.max(scores, dim=1) + pred = as_numpy(pred.squeeze(0).cpu()) - _, preds = torch.max(pred.data.cpu(), dim=1) - preds = as_numpy(preds.squeeze(0)) + time_meter.update(time.time() - tic) # calculate accuracy - acc, pix = accuracy(preds, seg_label) - intersection, union = intersectionAndUnion(preds, seg_label, args.num_class) + acc, pix = accuracy(pred, seg_label) + intersection, union = intersectionAndUnion(pred, seg_label, args.num_class) acc_meter.update(acc, pix) intersection_meter.update(intersection) union_meter.update(union) - print('[{}] iter {}, accuracy: {}' - .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - i, acc)) # visualization if args.visualize: visualize_result( (batch_data['img_ori'], seg_label, batch_data['info']), - preds, args) + pred, args) + + pbar.update(1) + # summary iou = intersection_meter.sum / (union_meter.sum + 1e-10) for i, _iou in enumerate(iou): print('class [{}], IoU: {}'.format(i, _iou)) print('[Eval Summary]:') - print('Mean IoU: {:.4}, Accuracy: {:.2f}%' - .format(iou.mean(), acc_meter.average()*100)) + print('Mean IoU: {:.4}, Accuracy: {:.2f}%, Inference Time: {:.4}s' + .format(iou.mean(), acc_meter.average()*100, time_meter.average())) def main(args): diff --git a/eval_multipro.py b/eval_multipro.py index 52bd0a79..d2b4d789 100644 --- a/eval_multipro.py +++ b/eval_multipro.py @@ -1,6 +1,5 @@ # System libs import os -import datetime import argparse from distutils.version import LooseVersion from multiprocessing import Queue, Process @@ -23,14 +22,14 @@ colors = loadmat('data/color150.mat')['colors'] -def visualize_result(data, preds, args): +def visualize_result(data, pred, args): (img, seg, info) = data # segmentation seg_color = colorEncode(seg, colors) # prediction - pred_color = colorEncode(preds, colors) + pred_color = colorEncode(pred, colors) # aggregate images and save im_vis = np.concatenate((img, seg_color, pred_color), @@ -42,19 +41,18 @@ def visualize_result(data, preds, args): def evaluate(segmentation_module, loader, args, dev_id, result_queue): - segmentation_module.eval() - for i, batch_data in enumerate(loader): + for batch_data in loader: # process data batch_data = batch_data[0] seg_label = as_numpy(batch_data['seg_label'][0]) - img_resized_list = batch_data['img_data'] with torch.no_grad(): segSize = (seg_label.shape[0], seg_label.shape[1]) - pred = torch.zeros(1, args.num_class, segSize[0], segSize[1]) + scores = torch.zeros(1, args.num_class, segSize[0], segSize[1]) + scores = async_copy_to(scores, dev_id) for img in img_resized_list: feed_dict = batch_data.copy() @@ -64,22 +62,22 @@ def evaluate(segmentation_module, loader, args, dev_id, result_queue): feed_dict = async_copy_to(feed_dict, dev_id) # forward pass - pred_tmp = segmentation_module(feed_dict, segSize=segSize) - pred = pred + pred_tmp.cpu() / len(args.imgSize) + scores_tmp = segmentation_module(feed_dict, segSize=segSize) + scores = scores + scores_tmp / len(args.imgSize) - _, preds = torch.max(pred.data.cpu(), dim=1) - preds = as_numpy(preds.squeeze(0)) + _, pred = torch.max(scores, dim=1) + pred = as_numpy(pred.squeeze(0).cpu()) # calculate accuracy and SEND THEM TO MASTER - acc, pix = accuracy(preds, seg_label) - intersection, union = intersectionAndUnion(preds, seg_label, args.num_class) + acc, pix = accuracy(pred, seg_label) + intersection, union = intersectionAndUnion(pred, seg_label, args.num_class) result_queue.put_nowait((acc, pix, intersection, union)) # visualization if args.visualize: visualize_result( (batch_data['img_ori'], seg_label, batch_data['info']), - preds, args) + pred, args) def worker(args, dev_id, start_idx, end_idx, result_queue): @@ -118,6 +116,7 @@ def worker(args, dev_id, start_idx, end_idx, result_queue): # Main loop evaluate(segmentation_module, loader_val, args, dev_id, result_queue) + def main(args): # Parse device ids default_dev, *parallel_dev = parse_devices(args.devices) @@ -145,7 +144,7 @@ def main(args): start_idx = dev_id * nr_files_per_dev end_idx = min(start_idx + nr_files_per_dev, nr_files) proc = Process(target=worker, args=(args, dev_id, start_idx, end_idx, result_queue)) - print('process:%d, start_idx:%d, end_idx:%d' % (dev_id, start_idx, end_idx)) + print('process:{}, start_idx:{}, end_idx:{}'.format(dev_id, start_idx, end_idx)) proc.start() procs.append(proc) @@ -164,6 +163,7 @@ def main(args): for p in procs: p.join() + # summary iou = intersection_meter.sum / (union_meter.sum + 1e-10) for i, _iou in enumerate(iou): print('class [{}], IoU: {}'.format(i, _iou)) diff --git a/lib/nn/parallel/data_parallel.py b/lib/nn/parallel/data_parallel.py index 8e9b72b8..376fc038 100644 --- a/lib/nn/parallel/data_parallel.py +++ b/lib/nn/parallel/data_parallel.py @@ -3,16 +3,15 @@ import torch.cuda as cuda import torch.nn as nn import torch -from torch.autograd import Variable import collections from torch.nn.parallel._functions import Gather + __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] + def async_copy_to(obj, dev, main_stream=None): if torch.is_tensor(obj): - obj = Variable(obj) - if isinstance(obj, Variable): v = obj.cuda(dev, non_blocking=True) if main_stream is not None: v.data.record_stream(main_stream) @@ -32,7 +31,7 @@ def dict_gather(outputs, target_device, dim=0): """ def gather_map(outputs): out = outputs[0] - if isinstance(out, Variable): + if torch.is_tensor(out): # MJY(20180330) HACK:: force nr_dims > 0 if out.dim() == 0: outputs = [o.unsqueeze(0) for o in outputs] diff --git a/test.py b/test.py index 2608baa3..b9e4e42e 100644 --- a/test.py +++ b/test.py @@ -16,14 +16,16 @@ from lib.utils import as_numpy, mark_volatile import lib.utils.data as torchdata import cv2 +from tqdm import tqdm +colors = loadmat('data/color150.mat')['colors'] -def visualize_result(data, preds, args): - colors = loadmat('data/color150.mat')['colors'] + +def visualize_result(data, pred, args): (img, info) = data # prediction - pred_color = colorEncode(preds, colors) + pred_color = colorEncode(pred, colors) # aggregate images and save im_vis = np.concatenate((img, pred_color), @@ -37,16 +39,17 @@ def visualize_result(data, preds, args): def test(segmentation_module, loader, args): segmentation_module.eval() - for i, batch_data in enumerate(loader): + pbar = tqdm(total=len(loader)) + for batch_data in loader: # process data batch_data = batch_data[0] segSize = (batch_data['img_ori'].shape[0], batch_data['img_ori'].shape[1]) - img_resized_list = batch_data['img_data'] with torch.no_grad(): - pred = torch.zeros(1, args.num_class, segSize[0], segSize[1]) + scores = torch.zeros(1, args.num_class, segSize[0], segSize[1]) + scores = async_copy_to(scores, args.gpu_id) for img in img_resized_list: feed_dict = batch_data.copy() @@ -57,18 +60,17 @@ def test(segmentation_module, loader, args): # forward pass pred_tmp = segmentation_module(feed_dict, segSize=segSize) - pred = pred + pred_tmp.cpu() / len(args.imgSize) + scores = scores + pred_tmp / len(args.imgSize) - _, preds = torch.max(pred, dim=1) - preds = as_numpy(preds.squeeze(0)) + _, pred = torch.max(scores, dim=1) + pred = as_numpy(pred.squeeze(0).cpu()) # visualization visualize_result( (batch_data['img_ori'], batch_data['info']), - preds, args) + pred, args) - print('[{}] iter {}' - .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i)) + pbar.update(1) def main(args):