Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix for the case where there are no detections #9784

Merged
merged 1 commit into from
Apr 23, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 48 additions & 31 deletions example/ssd/detect/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import mxnet as mx
import numpy as np
from timeit import default_timer as timer
from dataset.testdb import TestDB
from dataset.iterator import DetIter
import logging

class Detector(object):
"""
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \
load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
if symbol is None:
symbol = load_symbol
self.mod = mx.mod.Module(symbol, label_names=None, context=ctx)
self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
if not isinstance(data_shape, tuple):
data_shape = (data_shape, data_shape)
self.data_shape = data_shape
Expand Down Expand Up @@ -81,13 +81,9 @@ def detect(self, det_iter, show_timer=False):
detections = self.mod.predict(det_iter).asnumpy()
time_elapsed = timer() - start
if show_timer:
print("Detection time for {} images: {:.4f} sec".format(
logging.info("Detection time for {} images: {:.4f} sec".format(
num_images, time_elapsed))
result = []
for i in range(detections.shape[0]):
det = detections[i, :, :]
res = det[np.where(det[:, 0] >= 0)[0]]
result.append(res)
result = Detector.filter_positive_detections(detections)
return result

def im_detect(self, im_list, root_dir=None, extension=None, show_timer=False):
Expand Down Expand Up @@ -136,31 +132,52 @@ class names
height = img.shape[0]
width = img.shape[1]
colors = dict()
for i in range(dets.shape[0]):
cls_id = int(dets[i, 0])
if cls_id >= 0:
score = dets[i, 1]
if score > thresh:
if cls_id not in colors:
colors[cls_id] = (random.random(), random.random(), random.random())
xmin = int(dets[i, 2] * width)
ymin = int(dets[i, 3] * height)
xmax = int(dets[i, 4] * width)
ymax = int(dets[i, 5] * height)
rect = plt.Rectangle((xmin, ymin), xmax - xmin,
ymax - ymin, fill=False,
edgecolor=colors[cls_id],
linewidth=3.5)
plt.gca().add_patch(rect)
class_name = str(cls_id)
if classes and len(classes) > cls_id:
class_name = classes[cls_id]
plt.gca().text(xmin, ymin - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor=colors[cls_id], alpha=0.5),
for det in dets:
(klass, score, x0, y0, x1, y1) = det
if score < thresh:
continue
cls_id = int(klass)
if cls_id not in colors:
colors[cls_id] = (random.random(), random.random(), random.random())
xmin = int(x0 * width)
ymin = int(y0 * height)
xmax = int(x1 * width)
ymax = int(y1 * height)
rect = plt.Rectangle((xmin, ymin), xmax - xmin,
ymax - ymin, fill=False,
edgecolor=colors[cls_id],
linewidth=3.5)
plt.gca().add_patch(rect)
class_name = str(cls_id)
if classes and len(classes) > cls_id:
class_name = classes[cls_id]
plt.gca().text(xmin, ymin - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor=colors[cls_id], alpha=0.5),
fontsize=12, color='white')
plt.show()

@staticmethod
def filter_positive_detections(detections):
"""
First column (class id) is -1 for negative detections
:param detections:
:return:
"""
class_idx = 0
assert(isinstance(detections, mx.nd.NDArray) or isinstance(detections, np.ndarray))
detections_per_image = []
# for each image
for i in range(detections.shape[0]):
result = []
det = detections[i, :, :]
for obj in det:
if obj[class_idx] >= 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about self.class_idx with default 0 in __init__

Copy link
Contributor Author

@larroy larroy Feb 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not possible, the method is static. That's why there was a class variable on the first place when the review started. Seems we are going in circles.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to use member function rather than static. Is there any specific reason we can't do it? Or it just serves as better self-explained code?
Anyways, I don't bother overcomplicating such a simple class, either way LGTM and I am just trying to settle it down quickly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filter_positive_detections doesn't need any member data, it's a pure function, reflected by @staticmethod so one knows that is not going to mutate class state. Anyway, maybe is my pedantic defensive programming from other safer programming languages, in Python seems everyone codes however they want.

result.append(obj)
detections_per_image.append(result)
logging.info("%d positive detections", len(result))
return detections_per_image

def detect_and_visualize(self, im_list, root_dir=None, extension=None,
classes=[], thresh=0.6, show_timer=False):
"""
Expand All @@ -187,5 +204,5 @@ def detect_and_visualize(self, im_list, root_dir=None, extension=None,
assert len(dets) == len(im_list)
for k, det in enumerate(dets):
img = cv2.imread(im_list[k])
img[:, :, (0, 1, 2)] = img[:, :, (2, 1, 0)]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
self.visualize_detection(img, det, classes, thresh)