Skip to content

Commit

Permalink
Add tta to HTC and Cascade RCNN (open-mmlab#1251)
Browse files Browse the repository at this point in the history
* add tta to HTC and Caccade RCNN

* format file with yapf

* fix import error with isort

* Update htc.py

* Update cascade_rcnn.py

* fix bug

* delete some redundant codes
  • Loading branch information
d0ng1ee authored and hellock committed Sep 14, 2019
1 parent d990625 commit 53a2c6f
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 8 deletions.
111 changes: 107 additions & 4 deletions mmdet/models/detectors/cascade_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import torch
import torch.nn as nn

from mmdet.core import (bbox2result, bbox2roi, build_assigner, build_sampler,
merge_aug_masks)
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
build_sampler, merge_aug_bboxes, merge_aug_masks,
multiclass_nms)
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector
Expand Down Expand Up @@ -399,8 +400,110 @@ def simple_test(self, img, img_meta, proposals=None, rescale=False):

return results

def aug_test(self, img, img_meta, proposals=None, rescale=False):
raise NotImplementedError
def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
"""Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
# recompute feats to save memory
proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

rcnn_test_cfg = self.test_cfg.rcnn
aug_bboxes = []
aug_scores = []
for x, img_meta in zip(self.extract_feats(imgs), img_metas):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']

proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
scale_factor, flip)
# "ms" in variable names means multi-stage
ms_scores = []

rois = bbox2roi([proposals])
for i in range(self.num_stages):
bbox_roi_extractor = self.bbox_roi_extractor[i]
bbox_head = self.bbox_head[i]

bbox_feats = bbox_roi_extractor(
x[:len(bbox_roi_extractor.featmap_strides)], rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)

cls_score, bbox_pred = bbox_head(bbox_feats)
ms_scores.append(cls_score)

if i < self.num_stages - 1:
bbox_label = cls_score.argmax(dim=1)
rois = bbox_head.regress_by_class(rois, bbox_label,
bbox_pred, img_meta[0])

cls_score = sum(ms_scores) / float(len(ms_scores))
bboxes, scores = self.bbox_head[-1].get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False,
cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)

# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms,
rcnn_test_cfg.max_per_img)

bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head[-1].num_classes)

if self.with_mask:
if det_bboxes.shape[0] == 0:
segm_result = [[]
for _ in range(self.mask_head[-1].num_classes -
1)]
else:
aug_masks = []
aug_img_metas = []
for x, img_meta in zip(self.extract_feats(imgs), img_metas):
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
scale_factor, flip)
mask_rois = bbox2roi([_bboxes])
for i in range(self.num_stages):
mask_feats = self.mask_roi_extractor[i](
x[:len(self.mask_roi_extractor[i].featmap_strides
)], mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats)
mask_pred = self.mask_head[i](mask_feats)
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
aug_img_metas.append(img_meta)
merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
self.test_cfg.rcnn)

ori_shape = img_metas[0][0]['ori_shape']
segm_result = self.mask_head[-1].get_seg_masks(
merged_masks,
det_bboxes,
det_labels,
rcnn_test_cfg,
ori_shape,
scale_factor=1.0,
rescale=False)
return bbox_result, segm_result
else:
return bbox_result

def show_result(self, data, result, **kwargs):
if self.with_mask:
Expand Down
128 changes: 124 additions & 4 deletions mmdet/models/detectors/htc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
import torch.nn.functional as F

from mmdet.core import (bbox2result, bbox2roi, build_assigner, build_sampler,
merge_aug_masks)
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
build_sampler, merge_aug_bboxes, merge_aug_masks,
multiclass_nms)
from .. import builder
from ..registry import DETECTORS
from .cascade_rcnn import CascadeRCNN
Expand Down Expand Up @@ -431,5 +432,124 @@ def simple_test(self, img, img_meta, proposals=None, rescale=False):

return results

def aug_test(self, img, img_meta, proposals=None, rescale=False):
raise NotImplementedError
def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
"""Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
if self.with_semantic:
semantic_feats = [
self.semantic_head(feat)[1]
for feat in self.extract_feats(imgs)
]
else:
semantic_feats = [None] * len(img_metas)

# recompute feats to save memory
proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

rcnn_test_cfg = self.test_cfg.rcnn
aug_bboxes = []
aug_scores = []
for x, img_meta, semantic in zip(
self.extract_feats(imgs), img_metas, semantic_feats):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']

proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
scale_factor, flip)
# "ms" in variable names means multi-stage
ms_scores = []

rois = bbox2roi([proposals])
for i in range(self.num_stages):
bbox_head = self.bbox_head[i]
cls_score, bbox_pred = self._bbox_forward_test(
i, x, rois, semantic_feat=semantic)
ms_scores.append(cls_score)

if i < self.num_stages - 1:
bbox_label = cls_score.argmax(dim=1)
rois = bbox_head.regress_by_class(rois, bbox_label,
bbox_pred, img_meta[0])

cls_score = sum(ms_scores) / float(len(ms_scores))
bboxes, scores = self.bbox_head[-1].get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False,
cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)

# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms,
rcnn_test_cfg.max_per_img)

bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head[-1].num_classes)

if self.with_mask:
if det_bboxes.shape[0] == 0:
segm_result = [[]
for _ in range(self.mask_head[-1].num_classes -
1)]
else:
aug_masks = []
aug_img_metas = []
for x, img_meta, semantic in zip(
self.extract_feats(imgs), img_metas, semantic_feats):
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
scale_factor, flip)
mask_rois = bbox2roi([_bboxes])
mask_feats = self.mask_roi_extractor[-1](
x[:len(self.mask_roi_extractor[-1].featmap_strides)],
mask_rois)
if self.with_semantic:
semantic_feat = semantic
mask_semantic_feat = self.semantic_roi_extractor(
[semantic_feat], mask_rois)
if mask_semantic_feat.shape[-2:] != mask_feats.shape[
-2:]:
mask_semantic_feat = F.adaptive_avg_pool2d(
mask_semantic_feat, mask_feats.shape[-2:])
mask_feats += mask_semantic_feat
last_feat = None
for i in range(self.num_stages):
mask_head = self.mask_head[i]
if self.mask_info_flow:
mask_pred, last_feat = mask_head(
mask_feats, last_feat)
else:
mask_pred = mask_head(mask_feats)
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
aug_img_metas.append(img_meta)
merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
self.test_cfg.rcnn)

ori_shape = img_metas[0][0]['ori_shape']
segm_result = self.mask_head[-1].get_seg_masks(
merged_masks,
det_bboxes,
det_labels,
rcnn_test_cfg,
ori_shape,
scale_factor=1.0,
rescale=False)
return bbox_result, segm_result
else:
return bbox_result

0 comments on commit 53a2c6f

Please sign in to comment.