From b320fce29468ce96618a251d8492709de47fb46f Mon Sep 17 00:00:00 2001
From: tilt
Date: Wed, 24 Jan 2018 10:34:10 +0100
Subject: [PATCH 1/2] Fix batch-wise inference

---
 layers/functions/detection.py | 29 ++++++++++-------------------
 ssd.py                        |  3 ++-
 2 files changed, 12 insertions(+), 20 deletions(-)

diff --git a/layers/functions/detection.py b/layers/functions/detection.py
index 22efa7c76..1f2fddc0b 100644
--- a/layers/functions/detection.py
+++ b/layers/functions/detection.py
@@ -1,8 +1,5 @@
 import torch
-import torch.nn as nn
-import torch.backends.cudnn as cudnn
 from torch.autograd import Function
-from torch.autograd import Variable
 from ..box_utils import decode, nms
 from data import v2 as cfg
 
@@ -23,7 +20,6 @@ def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
             raise ValueError('nms_threshold must be non negative.')
         self.conf_thresh = conf_thresh
         self.variance = cfg['variance']
-        self.output = torch.zeros(1, self.num_classes, self.top_k, 5)
 
     def forward(self, loc_data, conf_data, prior_data):
         """
@@ -37,21 +33,16 @@ def forward(self, loc_data, conf_data, prior_data):
         """
         num = loc_data.size(0)  # batch size
         num_priors = prior_data.size(0)
-        self.output.zero_()
-        if num == 1:
-            # size batch x num_classes x num_priors
-            conf_preds = conf_data.t().contiguous().unsqueeze(0)
-        else:
-            conf_preds = conf_data.view(num, num_priors,
-                                        self.num_classes).transpose(2, 1)
-            self.output.expand_(num, self.num_classes, self.top_k, 5)
+        output = torch.zeros(num, self.num_classes, self.top_k, 5)
+        conf_preds = conf_data.view(num, num_priors,
+                                    self.num_classes).transpose(2, 1)
 
         # Decode predictions into bboxes.
         for i in range(num):
             decoded_boxes = decode(loc_data[i], prior_data, self.variance)
             # For each class, perform nms
             conf_scores = conf_preds[i].clone()
-            num_det = 0
+
             for cl in range(1, self.num_classes):
                 c_mask = conf_scores[cl].gt(self.conf_thresh)
                 scores = conf_scores[cl][c_mask]
@@ -61,11 +52,11 @@ def forward(self, loc_data, conf_data, prior_data):
                 boxes = decoded_boxes[l_mask].view(-1, 4)
                 # idx of highest scoring and non-overlapping boxes per class
                 ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
-                self.output[i, cl, :count] = \
+                output[i, cl, :count] = \
                     torch.cat((scores[ids[:count]].unsqueeze(1),
                                boxes[ids[:count]]), 1)
-        flt = self.output.view(-1, 5)
-        _, idx = flt[:, 0].sort(0)
-        _, rank = idx.sort(0)
-        flt[(rank >= self.top_k).unsqueeze(1).expand_as(flt)].fill_(0)
-        return self.output
+        flt = output.contiguous().view(num, -1, 5)
+        _, idx = flt[:, :, 0].sort(1)
+        _, rank = idx.sort(1)
+        flt[(rank >= self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
+        return output
diff --git a/ssd.py b/ssd.py
index cd570064c..00583da18 100644
--- a/ssd.py
+++ b/ssd.py
@@ -97,7 +97,8 @@ def forward(self, x):
         if self.phase == "test":
             output = self.detect(
                 loc.view(loc.size(0), -1, 4),                   # loc preds
-                self.softmax(conf.view(-1, self.num_classes)),  # conf preds
+                self.softmax(conf.view(-1, self.num_classes)) \
+                    .view(conf.size(0), -1, self.num_classes),  # conf preds
                 self.priors.type(type(x.data))                  # default boxes
             )
         else:

From 2ecb75f0abc90dfa14e5fa207cf4a94f17d176e3 Mon Sep 17 00:00:00 2001
From: tilt
Date: Thu, 25 Jan 2018 10:29:24 +0100
Subject: [PATCH 2/2] Fix non-maximum suppression

Suppress all but the top_k most confident predictions. Formerly only the
least confident `top_k` were filtered out in the final "detect" step.
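As a rough standalone sketch of the intended clamping (illustrative only, not
the code in this patch; it assumes a contiguous tensor laid out like Detect's
output, with rows of [score, box coordinates]):

    import torch

    def keep_top_k(detections, top_k):
        # detections: (batch, num_classes, top_k, 5), each row [score, x1, y1, x2, y2]
        batch = detections.size(0)
        flt = detections.view(batch, -1, 5)              # flatten classes x top_k
        _, idx = flt[:, :, 0].sort(1, descending=True)   # order rows by score, best first
        _, rank = idx.sort(1)                            # rank of each row (0 = best)
        # zero every row that is not among the top_k most confident, in place
        flt.masked_fill_((rank >= top_k).unsqueeze(-1), 0)
        return detections

Using masked_fill_ on the view mutates the underlying storage directly, so the
zeroing is not lost to the temporary copy that plain boolean indexing would
create.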
---
 layers/functions/detection.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/layers/functions/detection.py b/layers/functions/detection.py
index 1f2fddc0b..3bd49decb 100644
--- a/layers/functions/detection.py
+++ b/layers/functions/detection.py
@@ -56,7 +56,7 @@ def forward(self, loc_data, conf_data, prior_data):
                     torch.cat((scores[ids[:count]].unsqueeze(1),
                                boxes[ids[:count]]), 1)
         flt = output.contiguous().view(num, -1, 5)
-        _, idx = flt[:, :, 0].sort(1)
+        _, idx = flt[:, :, 0].sort(1, descending=True)
         _, rank = idx.sort(1)
-        flt[(rank >= self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
+        flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
         return output
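For context, a sketch of how a caller might unpack the batch-first
(batch, num_classes, top_k, 5) tensor returned by Detect.forward after these
two patches; the helper name and the score threshold are illustrative, not
part of the repository:

    import torch

    def collect_detections(detections, score_thresh=0.6):
        # detections: (batch, num_classes, top_k, 5); class 0 is the background slot
        results = []
        batch, num_classes = detections.size(0), detections.size(1)
        for b in range(batch):
            per_image = []
            for cl in range(1, num_classes):            # skip background
                for row in detections[b, cl]:
                    score = row[0].item()
                    if score <= score_thresh:
                        continue                        # zero padding or weak detection
                    per_image.append((cl, score, row[1:].tolist()))
            results.append(per_image)
        return results

    # e.g. with a dummy batch of two images, three classes, top_k of 200:
    dummy = torch.zeros(2, 3, 200, 5)
    print(collect_detections(dummy))                    # -> [[], []]

Each entry of results then holds (class_index, score, box) tuples for one
image of the batch.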