Skip to content

Commit

Permalink
feat(demo): add class agnostic nms for demo and numpy based postproce…
Browse files Browse the repository at this point in the history
…ssing (#588)

feat(demo): add class agnostic nms for demo and numpy based postprocessing (#588)

Co-authored-by: Feng Wang <wangfeng19950315@163.com>
  • Loading branch information
Joker316701882 and FateScript committed Aug 26, 2021
1 parent 04d511b commit e579d81
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
3 changes: 2 additions & 1 deletion tools/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def inference(self, img):
if self.decoder is not None:
outputs = self.decoder(outputs, dtype=outputs.type())
outputs = postprocess(
outputs, self.num_classes, self.confthre, self.nmsthre
outputs, self.num_classes, self.confthre,
self.nmsthre, class_agnostic=True
)
logger.info("Infer time: {:.4f}s".format(time.time() - t0))
return outputs, img_info
Expand Down
22 changes: 15 additions & 7 deletions yolox/utils/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def filter_box(output, scale_range):
return output[keep]


def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
Expand All @@ -53,12 +53,20 @@ def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
if not detections.size(0):
continue

nms_out_index = torchvision.ops.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thre,
)
if class_agnostic:
nms_out_index = torchvision.ops.nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
nms_thre,
)
else:
nms_out_index = torchvision.ops.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thre,
)

detections = detections[nms_out_index]
if output[i] is None:
output[i] = detections
Expand Down
30 changes: 29 additions & 1 deletion yolox/utils/demo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,17 @@ def nms(boxes, scores, nms_thr):
return keep


def multiclass_nms(boxes, scores, nms_thr, score_thr):
def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
"""Multiclass NMS implemented in Numpy"""
if class_agnostic:
nms_method = multiclass_nms_class_agnostic
else:
nms_method = multiclass_nms_class_aware
return nms_method(boxes, scores, nms_thr, score_thr)


def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
"""Multiclass NMS implemented in Numpy. Class-aware version."""
final_dets = []
num_classes = scores.shape[1]
for cls_ind in range(num_classes):
Expand All @@ -68,6 +77,25 @@ def multiclass_nms(boxes, scores, nms_thr, score_thr):
return np.concatenate(final_dets, 0)


def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
"""Multiclass NMS implemented in Numpy. Class-agnostic version."""
cls_inds = scores.argmax(1)
cls_scores = scores[np.arange(len(cls_inds)), cls_inds]

valid_score_mask = cls_scores > score_thr
if valid_score_mask.sum() == 0:
return None
valid_scores = cls_scores[valid_score_mask]
valid_boxes = boxes[valid_score_mask]
valid_cls_inds = cls_inds[valid_score_mask]
keep = nms(valid_boxes, valid_scores, nms_thr)
if keep:
dets = np.concatenate(
[valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
)
return dets


def demo_postprocess(outputs, img_size, p6=False):

grids = []
Expand Down

0 comments on commit e579d81

Please sign in to comment.