
Release the implementation of adding the BorderAlign module to RetinaNet? #15

Closed
lingyunwu14 opened this issue Sep 22, 2020 · 7 comments
Labels: help wanted

@lingyunwu14

lingyunwu14 commented Sep 22, 2020

Thanks for your excellent work and nice open-source code.

"Our BorderDet can be easily integrated with the many popular object detectors, e.g. RetinaNet and FPN.To prove the generalization of the BorderDet, we first add the proposed border alignment module to the RetinaNet. For a fair comparison, without modifying any setting of RetinaNet, we directly select the one with the highest score from the nine prediction boxes of each pixel to refine."

I tried to reproduce it as described in the paper, but the model does not converge well. So, could you please release the implementation of adding the BorderAlign module to RetinaNet? Thanks.

@Maycbj
Member

Maycbj commented Sep 24, 2020

I am very sorry, but all of our experiments were done in another codebase; in this released codebase we have only reproduced the primary experiment, without BorderAlign on RetinaNet.
BorderAlign on RetinaNet is very similar to BorderAlign on FCOS. Each location on the feature map has nine anchors and nine sets of offsets, which are combined into coarse bounding boxes, and we use these coarse bounding boxes to extract the border features. The feature map in front of BorderAlign is (N, 9*5*C, H, W): 5 is the number of border features (origin, left, right, top, bottom) and 9 is the number of anchor boxes.
I believe it is easy to implement the experiment. If you have any other questions, please reply to me directly.
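
For concreteness, here is a minimal shape sketch of that layout (the sizes and variable names are illustrative, not taken from our codebase):

import torch

N, C, H, W = 2, 256, 50, 68   # batch, FPN channels, feature size (made up)
A, F = 9, 5                   # 9 anchors; 5 border features: origin, left, right, top, bottom

# Feature map in front of BorderAlign: one group of F*C channels per anchor.
feat = torch.randn(N, A * F * C, H, W)

# Make the anchor and border-feature axes explicit.
feat = feat.reshape(N, A, F, C, H, W)
print(feat.shape)  # torch.Size([2, 9, 5, 256, 50, 68])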

@lingyunwu14
Author

lingyunwu14 commented Sep 25, 2020

Thanks for your reply. There may be something wrong with my understanding.

"we directly select the one with the highest score from the nine prediction boxes of each pixel to refine." I understand it to mean that the feature map in front of the BorderAlign of RetinaNet is the same as that of FCOS. Therefore, I did ArgMax on the coarse classification score first, got the highest one of the 9 anchors as the coarse bounding boxes, and kept the same operation as FCOS in the Border Align Module.

The following is my implementation of BorderHead; please take a look and tell me what is wrong. Thank you very much.

import math

import torch
import torch.nn as nn


class RetinaNetBorderHead(torch.nn.Module):
    def __init__(self, cfg, in_channels):
        super(RetinaNetBorderHead, self).__init__()
        self.num_classes = 80
        self.num_anchors = 9
        self.fpn_strides = [8, 16, 32, 64, 128]

        cls_tower = []
        bbox_tower = []
        for i in range(4):
            cls_tower.append(
                nn.Conv2d(in_channels, in_channels,
                    kernel_size=3, stride=1, padding=1))
            cls_tower.append(nn.ReLU())
            bbox_tower.append(
                nn.Conv2d(in_channels, in_channels,
                    kernel_size=3, stride=1, padding=1))
            bbox_tower.append(nn.ReLU())
        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
        self.cls_logits = nn.Conv2d(
            in_channels, self.num_anchors * self.num_classes, 
            kernel_size=3, stride=1, padding=1)
        self.bbox_pred = nn.Conv2d(
            in_channels, self.num_anchors * 4,
            kernel_size=3, stride=1, padding=1)

        # Border Align Module (BorderBranch is the same as your implementation)
        self.add_module("border_cls_subnet", BorderBranch(in_channels, 256))
        self.add_module("border_bbox_subnet", BorderBranch(in_channels, 128))
        self.border_cls_score = nn.Conv2d(
            in_channels, self.num_classes, kernel_size=1, stride=1)
        self.border_bbox_pred = nn.Conv2d(
            in_channels, 4, kernel_size=1, stride=1)

        # Initialization
        for modules in [self.cls_tower, self.bbox_tower, 
                        self.cls_logits, self.bbox_pred,
                        self.border_cls_subnet, self.border_bbox_subnet,
                        self.border_cls_score, self.border_bbox_pred]:
            for l in modules.modules():
                if isinstance(l, nn.Conv2d):
                    torch.nn.init.normal_(l.weight, mean=0, std=0.01)
                    torch.nn.init.constant_(l.bias, 0)
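        # initialize classification biases with the focal-loss prior (pi = 0.01), as in RetinaNet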
        bias_value = -math.log((1 - 0.01) / 0.01)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)
        torch.nn.init.constant_(self.border_cls_score.bias, bias_value)

    def forward(self, x, anchors, box_coder):
        anchors = list(zip(*anchors))
        logits = []
        bbox_reg = []
        # Border
        border_logits = []
        border_bbox_reg = []
        pre_bbox = []

        for level, (feature, anchor) in enumerate(zip(x, anchors)):
            # feature: [N, 256, H, W]
            cls_tower = self.cls_tower(feature) # [N, 256, H, W]
            bbox_tower = self.bbox_tower(feature) # [N, 256, H, W]
            cls_logits = self.cls_logits(cls_tower) # [N, 9*80, H, W]
            bbox_pred = self.bbox_pred(bbox_tower) # [N, 9*4, H, W]
            logits.append(cls_logits)
            bbox_reg.append(bbox_pred)

            # Border
            N, C, H, W = feature.shape
            pre_scores = cls_logits.clone().detach()
            pre_deltas = bbox_pred.clone().detach()
            with torch.no_grad():
                each_level_pre_bbox = []
                # for each image
                for each_anchor, each_pre_scores, each_pre_deltas in zip(anchor, pre_scores, pre_deltas):
                    # each_anchor num: H*W*9, each_pre_scores: [9*80, H, W], each_pre_deltas: [9*4, H, W]

                    # get argmax index on 9 anchors with each_pre_scores
                    each_pre_scores = each_pre_scores.permute(1, 2, 0).reshape(-1, self.num_anchors, self.num_classes)
                    # [H*W, 9, 80]
                    each_pre_scores = each_pre_scores.max(dim=2)[0]
                    # [H*W, 9]
                    argmax_score_index = each_pre_scores.argmax(dim=1)
                    # [H*W]

                    # get each_pre_deltas with argmax_score_index
                    each_pre_deltas = each_pre_deltas.permute(1, 2, 0).reshape(-1, self.num_anchors, 4)
                    # [H*W, 9, 4]
                    each_pre_deltas = each_pre_deltas[torch.arange(each_pre_deltas.size(0)), argmax_score_index, :]
                    # [H*W, 4], [center_x, center_y, w, h]

                    # get each_anchor with argmax_score_index
                    each_anchor_bbox = each_anchor.bbox.clone().detach()
                    each_anchor_bbox = each_anchor_bbox.reshape(-1, self.num_anchors, 4)
                    # [H*W, 9, 4]
                    each_anchor_bbox = each_anchor_bbox[torch.arange(each_anchor_bbox.size(0)), argmax_score_index, :]
                    # [H*W, 4], [x1, y1, x2, y2]

                    # decode [center_x, center_y, w, h] to [x1, y1, x2, y2] with anchor
                    each_img_pre_bbox = box_coder.decode(each_pre_deltas, each_anchor_bbox)
                    # [H*W, 4], [x1, y1, x2, y2]
                    each_level_pre_bbox.append(each_img_pre_bbox)

                each_level_pre_bbox = torch.stack(each_level_pre_bbox)
                # [N, H*W, 4], [x1, y1, x2, y2]
                pre_bbox.append(each_level_pre_bbox)

                # align to feature map scale
                align_boxes, wh = self.compute_border(each_level_pre_bbox, level, H, W)
                # align_boxes: [N, H*W, 4], [x1, y1, x2, y2]
                # wh: [N, H*W, 2], [w, h]

            border_cls_conv = self.border_cls_subnet(cls_tower, align_boxes, wh)
            # [N, 256, H, W]
            border_cls_logits = self.border_cls_score(border_cls_conv)
            # [N, 80, H, W]
            border_logits.append(border_cls_logits)

            border_reg_conv = self.border_bbox_subnet(bbox_tower, align_boxes, wh)
            # [N, 256, H, W]
            border_bbox_pred = self.border_bbox_pred(border_reg_conv)
            # [N, 4, H, W]
            border_bbox_reg.append(border_bbox_pred)

        if self.training:
            pre_bbox = torch.cat(pre_bbox, dim=1)
        return logits, bbox_reg, border_logits, border_bbox_reg, pre_bbox

@Maycbj
Member

Maycbj commented Sep 28, 2020

retinanet.res50.800size.1x.tsd.zip
This is the implementation of BD-RetinaNet; maybe it will help you.

@FateScript added the help wanted label Sep 29, 2020
@lingyunwu14
Author

retinanet.res50.800size.1x.tsd.zip
This is the implementation of BD-RetinaNet; maybe it will help you.

Thanks for sharing. I still have three questions:

  1. Is ExtremePointAlignSimple in retinanet.py equivalent to BorderAlign?
  2. In the forward function of RetinaNetHead, sec_reg_conv has different dimensions in training and inference (i.e. [N, 256, H, W] for training and [N, 9*256, H, W] for inference), but self.sec_bbox_pred is only suitable for training. sec_cls_conv has the same problem.
            sec_reg_conv = self.sec_bbox_tower(shared_bbox_tower, align_boxes, wh)
            sec_bbox_reg = self.sec_bbox_pred(sec_reg_conv.reshape(N, -1, int(H * W)))
            sec_bbox_reg_list.append(sec_bbox_reg.reshape(N, -1, H, W))
  3. mode == 'fir' is always used in the prepare_targets function of loss.py. Is that wrong? I think mode != 'fir' should be used in the second stage.
    sec_labels, sec_reg_targets = self.prepare_targets(pre_boxes_list, targets)

@Maycbj
Member

Maycbj commented Oct 11, 2020

  1. They are the same.
  2. sec_reg_conv has the same dimensions in training and inference: the feature dimensions are [N, H*W*9, 256] both during training and during testing.
  3. They are two hyper-parameters of box_coder; we have not run ablation experiments on them, and I suggest you run ablations for them yourself.

@lingyunwu14
Author

lingyunwu14 commented Oct 12, 2020

  1. They are the same.
  2. sec_reg_conv has the same dimensions in training and inference: the feature dimensions are [N, H*W*9, 256] both during training and during testing.
  3. They are two hyper-parameters of box_coder; we have not run ablation experiments on them, and I suggest you run ablations for them yourself.

Thank you very much, but I'm still a bit confused about question 2.

If the dimension of sec_reg_conv is [N, H*W*9, 256], how can it be reshaped to (N, -1, int(H * W))?
sec_bbox_reg = self.sec_bbox_pred(sec_reg_conv.reshape(N, -1, int(H * W)))
Firstly, sec_reg_conv is the align_conv returned by the forward function of BorderBranch, and align_conv is reshaped to [N, A*C, H, W].
align_conv = align_conv.reshape(N, C, H, W, A).permute(0, 4, 1, 2, 3).reshape(N, -1, H, W)
Secondly, A=1 for training and A=9 for inference, because anchors_per_fm randomly takes one value along the A dimension in the training stage.
anchors_per_fm = anchors_per_fm[range(num_box), rand_index]
This is my understanding, and I'm looking forward to your corrections.
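
To make my confusion concrete, here is a tiny runnable shape check (all sizes are made up; this is just how I read the code):

import torch

N, C, H, W = 2, 256, 50, 68  # made-up sizes

for A in (1, 9):  # A = 1 during training (random anchor), A = 9 during inference, on my reading
    # the reshape at the end of BorderBranch.forward:
    align_conv = torch.randn(N, C, H, W, A).permute(0, 4, 1, 2, 3).reshape(N, -1, H, W)
    print(A, align_conv.shape)  # [N, A*C, H, W]
    # the subsequent reshape in RetinaNetHead.forward:
    flat = align_conv.reshape(N, -1, int(H * W))
    print(A, flat.shape)  # [N, A*C, H*W]; self.sec_bbox_pred seems to expect A == 1 here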

@Maycbj
Member

Maycbj commented Feb 3, 2021

I'm sorry for replying to you so late.
We tried to train our method with 9 anchors based on RetinaNet, similar to the anchor-free detectors with BorderAlign, but it is computationally intensive. So we tried different strategies to reduce the amount of computation:

  1. choose the anchor with the highest score to train
  2. choose the anchor with the highest IoU to train
  3. randomly choose an anchor to train

Strategy 3) yields the best performance in our experiments, because all anchors are supervised equally, which helps to overcome overfitting; a sketch of it follows below.
Besides, I suggest that BorderDet based on an anchor-free detector may be a better choice for you to get better performance.
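
For reference, strategy 3) amounts to something like the following during training (a sketch with made-up sizes, mirroring the anchors_per_fm line you quoted; not our exact code):

import torch

num_box, A = 3400, 9                         # locations on one feature map, anchors per location
anchors_per_fm = torch.randn(num_box, A, 4)  # coarse boxes for all 9 anchors per location

# Randomly pick one anchor per location to supervise, so every anchor
# is trained with equal probability.
rand_index = torch.randint(A, (num_box,))
anchors_per_fm = anchors_per_fm[torch.arange(num_box), rand_index]  # [num_box, 4]
print(anchors_per_fm.shape)  # torch.Size([3400, 4])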
