Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix for the case where there are no detections
Browse files Browse the repository at this point in the history
Use cv2.cvtColor instead of np to convert from BGR to RGB
Fix context in detector
  • Loading branch information
larroy committed Feb 13, 2018
1 parent 83f6279 commit 17ff889
Showing 1 changed file with 49 additions and 31 deletions.
80 changes: 49 additions & 31 deletions example/ssd/detect/detector.py
Expand Up @@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import mxnet as mx
import numpy as np
from timeit import default_timer as timer
from dataset.testdb import TestDB
from dataset.iterator import DetIter
import logging

class Detector(object):
"""
Expand All @@ -43,6 +43,7 @@ class Detector(object):
ctx : mx.ctx
device to use, if None, use mx.cpu() as default context
"""
CLASS = 0
def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \
batch_size=1, ctx=None):
self.ctx = ctx
Expand All @@ -51,7 +52,7 @@ def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \
load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
if symbol is None:
symbol = load_symbol
self.mod = mx.mod.Module(symbol, label_names=None, context=ctx)
self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
if not isinstance(data_shape, tuple):
data_shape = (data_shape, data_shape)
self.data_shape = data_shape
Expand Down Expand Up @@ -81,13 +82,9 @@ def detect(self, det_iter, show_timer=False):
detections = self.mod.predict(det_iter).asnumpy()
time_elapsed = timer() - start
if show_timer:
print("Detection time for {} images: {:.4f} sec".format(
logging.info("Detection time for {} images: {:.4f} sec".format(
num_images, time_elapsed))
result = []
for i in range(detections.shape[0]):
det = detections[i, :, :]
res = det[np.where(det[:, 0] >= 0)[0]]
result.append(res)
result = Detector.filter_positive_detections(detections)
return result

def im_detect(self, im_list, root_dir=None, extension=None, show_timer=False):
Expand Down Expand Up @@ -136,31 +133,52 @@ class names
height = img.shape[0]
width = img.shape[1]
colors = dict()
for i in range(dets.shape[0]):
cls_id = int(dets[i, 0])
if cls_id >= 0:
score = dets[i, 1]
if score > thresh:
if cls_id not in colors:
colors[cls_id] = (random.random(), random.random(), random.random())
xmin = int(dets[i, 2] * width)
ymin = int(dets[i, 3] * height)
xmax = int(dets[i, 4] * width)
ymax = int(dets[i, 5] * height)
rect = plt.Rectangle((xmin, ymin), xmax - xmin,
ymax - ymin, fill=False,
edgecolor=colors[cls_id],
linewidth=3.5)
plt.gca().add_patch(rect)
class_name = str(cls_id)
if classes and len(classes) > cls_id:
class_name = classes[cls_id]
plt.gca().text(xmin, ymin - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor=colors[cls_id], alpha=0.5),
for det in dets:
(klass, score, x0, y0, x1, y1) = det
if score < thresh:
continue
cls_id = int(klass)
if cls_id not in colors:
colors[cls_id] = (random.random(), random.random(), random.random())
xmin = int(x0 * width)
ymin = int(y0 * height)
xmax = int(x1 * width)
ymax = int(y1 * height)
rect = plt.Rectangle((xmin, ymin), xmax - xmin,
ymax - ymin, fill=False,
edgecolor=colors[cls_id],
linewidth=3.5)
plt.gca().add_patch(rect)
class_name = str(cls_id)
if classes and len(classes) > cls_id:
class_name = classes[cls_id]
plt.gca().text(xmin, ymin - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor=colors[cls_id], alpha=0.5),
fontsize=12, color='white')
plt.show()

@staticmethod
def filter_positive_detections(detections):
    """
    Drop negative detections from a batch of detector outputs.

    The first column of every detection row (index ``Detector.CLASS``)
    holds the class id; rows whose class id is below zero are negative
    detections and are discarded.

    :param detections: ``mx.nd.NDArray`` or ``np.ndarray`` indexed as
        ``detections[image, row, column]``
    :return: list with one entry per image; each entry is a list of the
        rows of that image whose class id is >= 0 (empty when the image
        has no positive detections)
    """
    # isinstance (rather than an exact type match) also admits subclasses.
    assert isinstance(detections, (mx.nd.NDArray, np.ndarray))
    detections_per_image = []
    # for each image
    for i in range(detections.shape[0]):
        det = detections[i, :, :]
        result = [obj for obj in det if obj[Detector.CLASS] >= 0]
        detections_per_image.append(result)
        # Logged per image, where ``result`` is always bound, so a batch
        # with zero images never touches an undefined name.
        logging.info("%d positive detections", len(result))
    return detections_per_image

def detect_and_visualize(self, im_list, root_dir=None, extension=None,
classes=[], thresh=0.6, show_timer=False):
"""
Expand All @@ -187,5 +205,5 @@ def detect_and_visualize(self, im_list, root_dir=None, extension=None,
assert len(dets) == len(im_list)
for k, det in enumerate(dets):
img = cv2.imread(im_list[k])
img[:, :, (0, 1, 2)] = img[:, :, (2, 1, 0)]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
self.visualize_detection(img, det, classes, thresh)

0 comments on commit 17ff889

Please sign in to comment.