diff --git a/example/ssd/detect/detector.py b/example/ssd/detect/detector.py index 3a7f89fcadd1..2a64b5e5e9eb 100644 --- a/example/ssd/detect/detector.py +++ b/example/ssd/detect/detector.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. -from __future__ import print_function import mxnet as mx import numpy as np from timeit import default_timer as timer from dataset.testdb import TestDB from dataset.iterator import DetIter +import logging class Detector(object): """ @@ -43,6 +43,7 @@ class Detector(object): ctx : mx.ctx device to use, if None, use mx.cpu() as default context """ + CLASS = 0 def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \ batch_size=1, ctx=None): self.ctx = ctx @@ -51,7 +52,7 @@ def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \ load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch) if symbol is None: symbol = load_symbol - self.mod = mx.mod.Module(symbol, label_names=None, context=ctx) + self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx) if not isinstance(data_shape, tuple): data_shape = (data_shape, data_shape) self.data_shape = data_shape @@ -81,13 +82,9 @@ def detect(self, det_iter, show_timer=False): detections = self.mod.predict(det_iter).asnumpy() time_elapsed = timer() - start if show_timer: - print("Detection time for {} images: {:.4f} sec".format( + logging.info("Detection time for {} images: {:.4f} sec".format( num_images, time_elapsed)) - result = [] - for i in range(detections.shape[0]): - det = detections[i, :, :] - res = det[np.where(det[:, 0] >= 0)[0]] - result.append(res) + result = Detector.filter_positive_detections(detections) return result def im_detect(self, im_list, root_dir=None, extension=None, show_timer=False): @@ -136,31 +133,52 @@ class names height = img.shape[0] width = img.shape[1] colors = dict() - for i in range(dets.shape[0]): - cls_id = int(dets[i, 0]) - if cls_id >= 0: - score = dets[i, 1] - if score > thresh: - if cls_id not in colors: - colors[cls_id] = (random.random(), random.random(), random.random()) - xmin = int(dets[i, 2] * width) - ymin = int(dets[i, 3] * height) - xmax = int(dets[i, 4] * width) - ymax = int(dets[i, 5] * height) - rect = plt.Rectangle((xmin, ymin), xmax - xmin, - ymax - ymin, fill=False, - edgecolor=colors[cls_id], - linewidth=3.5) - plt.gca().add_patch(rect) - class_name = str(cls_id) - if classes and len(classes) > cls_id: - class_name = classes[cls_id] - plt.gca().text(xmin, ymin - 2, - '{:s} {:.3f}'.format(class_name, score), - bbox=dict(facecolor=colors[cls_id], alpha=0.5), + for det in dets: + (klass, score, x0, y0, x1, y1) = det + if score < thresh: + continue + cls_id = int(klass) + if cls_id not in colors: + colors[cls_id] = (random.random(), random.random(), random.random()) + xmin = int(x0 * width) + ymin = int(y0 * height) + xmax = int(x1 * width) + ymax = int(y1 * height) + rect = plt.Rectangle((xmin, ymin), xmax - xmin, + ymax - ymin, fill=False, + edgecolor=colors[cls_id], + linewidth=3.5) + plt.gca().add_patch(rect) + class_name = str(cls_id) + if classes and len(classes) > cls_id: + class_name = classes[cls_id] + plt.gca().text(xmin, ymin - 2, + '{:s} {:.3f}'.format(class_name, score), + bbox=dict(facecolor=colors[cls_id], alpha=0.5), fontsize=12, color='white') plt.show() + @staticmethod + def filter_positive_detections(detections): + """ + First column (class id) is -1 for negative detections + :param detections: + :return: + """ + print(type(detections)) + assert((type(detections) is mx.nd.NDArray) or (type(detections) is np.ndarray)) + detections_per_image = [] + # for each image + for i in range(detections.shape[0]): + result = [] + det = detections[i, :, :] + for obj in det: + if obj[Detector.CLASS] >= 0: + result.append(obj) + detections_per_image.append(result) + logging.info("%d positive detections", len(result)) + return detections_per_image + def detect_and_visualize(self, im_list, root_dir=None, extension=None, classes=[], thresh=0.6, show_timer=False): """ @@ -187,5 +205,5 @@ def detect_and_visualize(self, im_list, root_dir=None, extension=None, assert len(dets) == len(im_list) for k, det in enumerate(dets): img = cv2.imread(im_list[k]) - img[:, :, (0, 1, 2)] = img[:, :, (2, 1, 0)] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) self.visualize_detection(img, det, classes, thresh)