Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# System libs
import os
import datetime
import time
import argparse
from distutils.version import LooseVersion
# Numerical libs
Expand All @@ -16,18 +16,19 @@
from lib.utils import as_numpy, mark_volatile
import lib.utils.data as torchdata
import cv2
from tqdm import tqdm

colors = loadmat('data/color150.mat')['colors']


def visualize_result(data, preds, args):
def visualize_result(data, pred, args):
(img, seg, info) = data

# segmentation
seg_color = colorEncode(seg, colors)

# prediction
pred_color = colorEncode(preds, colors)
pred_color = colorEncode(pred, colors)

# aggregate images and save
im_vis = np.concatenate((img, seg_color, pred_color),
Expand All @@ -42,19 +43,22 @@ def evaluate(segmentation_module, loader, args):
acc_meter = AverageMeter()
intersection_meter = AverageMeter()
union_meter = AverageMeter()
time_meter = AverageMeter()

segmentation_module.eval()

for i, batch_data in enumerate(loader):
pbar = tqdm(total=len(loader))
for batch_data in loader:
# process data
batch_data = batch_data[0]
seg_label = as_numpy(batch_data['seg_label'][0])

img_resized_list = batch_data['img_data']

tic = time.time()
with torch.no_grad():
segSize = (seg_label.shape[0], seg_label.shape[1])
pred = torch.zeros(1, args.num_class, segSize[0], segSize[1])
scores = torch.zeros(1, args.num_class, segSize[0], segSize[1])
scores = async_copy_to(scores, args.gpu_id)

for img in img_resized_list:
feed_dict = batch_data.copy()
Expand All @@ -64,35 +68,37 @@ def evaluate(segmentation_module, loader, args):
feed_dict = async_copy_to(feed_dict, args.gpu_id)

# forward pass
pred_tmp = segmentation_module(feed_dict, segSize=segSize)
pred = pred + pred_tmp.cpu() / len(args.imgSize)
scores_tmp = segmentation_module(feed_dict, segSize=segSize)
scores = scores + scores_tmp / len(args.imgSize)

_, pred = torch.max(scores, dim=1)
pred = as_numpy(pred.squeeze(0).cpu())

_, preds = torch.max(pred.data.cpu(), dim=1)
preds = as_numpy(preds.squeeze(0))
time_meter.update(time.time() - tic)

# calculate accuracy
acc, pix = accuracy(preds, seg_label)
intersection, union = intersectionAndUnion(preds, seg_label, args.num_class)
acc, pix = accuracy(pred, seg_label)
intersection, union = intersectionAndUnion(pred, seg_label, args.num_class)
acc_meter.update(acc, pix)
intersection_meter.update(intersection)
union_meter.update(union)
print('[{}] iter {}, accuracy: {}'
.format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
i, acc))

# visualization
if args.visualize:
visualize_result(
(batch_data['img_ori'], seg_label, batch_data['info']),
preds, args)
pred, args)

pbar.update(1)

# summary
iou = intersection_meter.sum / (union_meter.sum + 1e-10)
for i, _iou in enumerate(iou):
print('class [{}], IoU: {}'.format(i, _iou))

print('[Eval Summary]:')
print('Mean IoU: {:.4}, Accuracy: {:.2f}%'
.format(iou.mean(), acc_meter.average()*100))
print('Mean IoU: {:.4}, Accuracy: {:.2f}%, Inference Time: {:.4}s'
.format(iou.mean(), acc_meter.average()*100, time_meter.average()))


def main(args):
Expand Down
30 changes: 15 additions & 15 deletions eval_multipro.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# System libs
import os
import datetime
import argparse
from distutils.version import LooseVersion
from multiprocessing import Queue, Process
Expand All @@ -23,14 +22,14 @@
colors = loadmat('data/color150.mat')['colors']


def visualize_result(data, preds, args):
def visualize_result(data, pred, args):
(img, seg, info) = data

# segmentation
seg_color = colorEncode(seg, colors)

# prediction
pred_color = colorEncode(preds, colors)
pred_color = colorEncode(pred, colors)

# aggregate images and save
im_vis = np.concatenate((img, seg_color, pred_color),
Expand All @@ -42,19 +41,18 @@ def visualize_result(data, preds, args):


def evaluate(segmentation_module, loader, args, dev_id, result_queue):

segmentation_module.eval()

for i, batch_data in enumerate(loader):
for batch_data in loader:
# process data
batch_data = batch_data[0]
seg_label = as_numpy(batch_data['seg_label'][0])

img_resized_list = batch_data['img_data']

with torch.no_grad():
segSize = (seg_label.shape[0], seg_label.shape[1])
pred = torch.zeros(1, args.num_class, segSize[0], segSize[1])
scores = torch.zeros(1, args.num_class, segSize[0], segSize[1])
scores = async_copy_to(scores, dev_id)

for img in img_resized_list:
feed_dict = batch_data.copy()
Expand All @@ -64,22 +62,22 @@ def evaluate(segmentation_module, loader, args, dev_id, result_queue):
feed_dict = async_copy_to(feed_dict, dev_id)

# forward pass
pred_tmp = segmentation_module(feed_dict, segSize=segSize)
pred = pred + pred_tmp.cpu() / len(args.imgSize)
scores_tmp = segmentation_module(feed_dict, segSize=segSize)
scores = scores + scores_tmp / len(args.imgSize)

_, preds = torch.max(pred.data.cpu(), dim=1)
preds = as_numpy(preds.squeeze(0))
_, pred = torch.max(scores, dim=1)
pred = as_numpy(pred.squeeze(0).cpu())

# calculate accuracy and SEND THEM TO MASTER
acc, pix = accuracy(preds, seg_label)
intersection, union = intersectionAndUnion(preds, seg_label, args.num_class)
acc, pix = accuracy(pred, seg_label)
intersection, union = intersectionAndUnion(pred, seg_label, args.num_class)
result_queue.put_nowait((acc, pix, intersection, union))

# visualization
if args.visualize:
visualize_result(
(batch_data['img_ori'], seg_label, batch_data['info']),
preds, args)
pred, args)


def worker(args, dev_id, start_idx, end_idx, result_queue):
Expand Down Expand Up @@ -118,6 +116,7 @@ def worker(args, dev_id, start_idx, end_idx, result_queue):
# Main loop
evaluate(segmentation_module, loader_val, args, dev_id, result_queue)


def main(args):
# Parse device ids
default_dev, *parallel_dev = parse_devices(args.devices)
Expand Down Expand Up @@ -145,7 +144,7 @@ def main(args):
start_idx = dev_id * nr_files_per_dev
end_idx = min(start_idx + nr_files_per_dev, nr_files)
proc = Process(target=worker, args=(args, dev_id, start_idx, end_idx, result_queue))
print('process:%d, start_idx:%d, end_idx:%d' % (dev_id, start_idx, end_idx))
print('process:{}, start_idx:{}, end_idx:{}'.format(dev_id, start_idx, end_idx))
proc.start()
procs.append(proc)

Expand All @@ -164,6 +163,7 @@ def main(args):
for p in procs:
p.join()

# summary
iou = intersection_meter.sum / (union_meter.sum + 1e-10)
for i, _iou in enumerate(iou):
print('class [{}], IoU: {}'.format(i, _iou))
Expand Down
7 changes: 3 additions & 4 deletions lib/nn/parallel/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
import torch.cuda as cuda
import torch.nn as nn
import torch
from torch.autograd import Variable
import collections
from torch.nn.parallel._functions import Gather


__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']


def async_copy_to(obj, dev, main_stream=None):
if torch.is_tensor(obj):
obj = Variable(obj)
if isinstance(obj, Variable):
v = obj.cuda(dev, non_blocking=True)
if main_stream is not None:
v.data.record_stream(main_stream)
Expand All @@ -32,7 +31,7 @@ def dict_gather(outputs, target_device, dim=0):
"""
def gather_map(outputs):
out = outputs[0]
if isinstance(out, Variable):
if torch.is_tensor(out):
# MJY(20180330) HACK:: force nr_dims > 0
if out.dim() == 0:
outputs = [o.unsqueeze(0) for o in outputs]
Expand Down
26 changes: 14 additions & 12 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
from lib.utils import as_numpy, mark_volatile
import lib.utils.data as torchdata
import cv2
from tqdm import tqdm

colors = loadmat('data/color150.mat')['colors']

def visualize_result(data, preds, args):
colors = loadmat('data/color150.mat')['colors']

def visualize_result(data, pred, args):
(img, info) = data

# prediction
pred_color = colorEncode(preds, colors)
pred_color = colorEncode(pred, colors)

# aggregate images and save
im_vis = np.concatenate((img, pred_color),
Expand All @@ -37,16 +39,17 @@ def visualize_result(data, preds, args):
def test(segmentation_module, loader, args):
segmentation_module.eval()

for i, batch_data in enumerate(loader):
pbar = tqdm(total=len(loader))
for batch_data in loader:
# process data
batch_data = batch_data[0]
segSize = (batch_data['img_ori'].shape[0],
batch_data['img_ori'].shape[1])

img_resized_list = batch_data['img_data']

with torch.no_grad():
pred = torch.zeros(1, args.num_class, segSize[0], segSize[1])
scores = torch.zeros(1, args.num_class, segSize[0], segSize[1])
scores = async_copy_to(scores, args.gpu_id)

for img in img_resized_list:
feed_dict = batch_data.copy()
Expand All @@ -57,18 +60,17 @@ def test(segmentation_module, loader, args):

# forward pass
pred_tmp = segmentation_module(feed_dict, segSize=segSize)
pred = pred + pred_tmp.cpu() / len(args.imgSize)
scores = scores + pred_tmp / len(args.imgSize)

_, preds = torch.max(pred, dim=1)
preds = as_numpy(preds.squeeze(0))
_, pred = torch.max(scores, dim=1)
pred = as_numpy(pred.squeeze(0).cpu())

# visualization
visualize_result(
(batch_data['img_ori'], batch_data['info']),
preds, args)
pred, args)

print('[{}] iter {}'
.format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i))
pbar.update(1)


def main(args):
Expand Down