Support Group-DETR (#84)
* support group detr

* fix matcher bug

* refine readme

* refine license

Co-authored-by: ntianhe ren <rentianhe@dgx061.scc.idea>
rentainhe and ntianhe ren committed Sep 28, 2022
1 parent e000f09 commit 068bf99
Showing 12 changed files with 1,229 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -80,6 +80,7 @@ Results and models are available in [model zoo](https://detrex.readthedocs.io/en
- [x] [DN-DETR (CVPR'2022 Oral)](./projects/dn_detr/)
- [x] [DN-Deformable-DETR (CVPR'2022 Oral)](./projects/dn_deformable_detr/)
- [x] [DINO (ArXiv'2022)](./projects/dino/)
- [x] [Group-DETR (ArXiv'2022)](./projects/group_detr/)

Please see [projects](./projects/) for the details about projects that are built based on detrex.

3 changes: 2 additions & 1 deletion projects/README.md
@@ -6,4 +6,5 @@ Here are projects that are built on detrex which show you use detrex as a librar
- [Conditional DETR for Fast Training Convergence](./conditional_detr/)
- [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](./dab_detr/)
- [DN-DETR: Accelerate DETR Training by Introducing Query DeNoising](./dn_detr/)
- [DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection](./dino)
- [DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection](./dino)
- [Group DETR: Fast DETR Training with Group-Wise One-to-Many Assignment](./group_detr/)
38 changes: 38 additions & 0 deletions projects/group_detr/README.md
@@ -0,0 +1,38 @@
## Group DETR: Fast DETR Training with Group-Wise One-to-Many Assignment

Qiang Chen, Xiaokang Chen, Jian Wang, Haocheng Feng, Junyu Han, Errui Ding, Gang Zeng, Jingdong Wang

[[`arXiv`](https://arxiv.org/abs/2207.13085)] [[`BibTeX`](#citing-group-detr)]

<div align="center">
<img src="./assets/group_detr_arch.png"/>
</div><br/>

**Note**: this is an implementation of `Conditional DETR + Group DETR`, i.e., the group-wise one-to-many assignment applied on top of a Conditional DETR baseline.
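For intuition, here is a minimal conceptual sketch of the group-wise one-to-many assignment (not the `GroupHungarianMatcher` code; the one-to-one `matcher(outputs, targets)` interface and the function name are assumptions for illustration): during training, every query group is matched one-to-one against the same ground truth, so each object collects one positive query per group, while inference keeps only a single group.

```python
def group_wise_matching(pred_logits, pred_boxes, targets, matcher, group_nums=11):
    """Run a one-to-one matcher independently on each query group.

    pred_logits: (bs, num_queries * group_nums, num_classes)
    pred_boxes:  (bs, num_queries * group_nums, 4)
    """
    num_queries = pred_logits.shape[1] // group_nums
    all_indices = []
    for g in range(group_nums):
        lo, hi = g * num_queries, (g + 1) * num_queries
        # every group sees the *same* targets, so each ground-truth box
        # receives group_nums positive queries overall
        outputs = {"pred_logits": pred_logits[:, lo:hi], "pred_boxes": pred_boxes[:, lo:hi]}
        all_indices.append(matcher(outputs, targets))
    return all_indices  # the loss is then computed per group and summed
```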

## Training
All configs can be trained with:
```bash
cd detrex
python tools/train_net.py --config-file projects/group_detr/configs/path/to/config.py --num-gpus 8
```
By default, we use 8 GPUs with a total batch size of 16 for training.

## Evaluation
Model evaluation can be done as follows:
```bash
cd detrex
python tools/train_net.py --config-file projects/group_detr/configs/path/to/config.py --eval-only train.init_checkpoint=/path/to/model_checkpoint
```

## Citing Group-DETR
If you find our work helpful for your research, please consider citing the following BibTeX entry.

```BibTeX
@article{chen2022group,
title={Group DETR: Fast DETR Training with Group-Wise One-to-Many Assignment},
author={Chen, Qiang and Chen, Xiaokang and Wang, Jian and Feng, Haocheng and Han, Junyu and Ding, Errui and Zeng, Gang and Wang, Jingdong},
journal={arXiv preprint arXiv:2207.13085},
year={2022}
}
```
Binary file added projects/group_detr/assets/group_detr_arch.png
22 changes: 22 additions & 0 deletions projects/group_detr/configs/group_detr_r50_50ep.py
@@ -0,0 +1,22 @@
from detrex.config import get_config
from .models.group_detr_r50 import model

dataloader = get_config("common/data/coco_detr.py").dataloader
optimizer = get_config("common/optim.py").AdamW
lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_50ep
train = get_config("common/train.py").train

# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.output_dir = "./output/group_detr_r50_50ep"
train.max_iter = 375000
train.clip_grad.enabled = True
train.clip_grad.params.max_norm = 0.1
train.clip_grad.params.norm_type = 2

# modify optimizer config
optimizer.weight_decay = 1e-4
optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1

# modify dataloader config
dataloader.train.num_workers = 16
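As a rough sanity check on the schedule (a back-of-envelope sketch; the ~118k-image size of COCO `train2017` and the 7500 iterations-per-epoch rounding used by detrex's common schedules are assumptions here), `train.max_iter = 375000` matches the 50-epoch setting at a total batch size of 16:

```python
# 50 epochs over COCO train2017 at total batch size 16
images_per_epoch = 118_287  # assumed size of COCO train2017
total_batch_size = 16
iters_per_epoch = images_per_epoch / total_batch_size  # ~7393, rounded up to 7500
print(50 * 7500)  # 375000 == train.max_iter
```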
94 changes: 94 additions & 0 deletions projects/group_detr/configs/models/group_detr_r50.py
@@ -0,0 +1,94 @@
import torch.nn as nn

from detrex.layers import PositionEmbeddingSine
from detrex.modeling.backbone import ResNet, BasicStem

from detectron2.config import LazyCall as L

from projects.group_detr.modeling import (
    GroupDETR,
    GroupDetrTransformer,
    GroupDetrTransformerDecoder,
    GroupDetrTransformerEncoder,
    GroupHungarianMatcher,
    GroupSetCriterion,
)


model = L(GroupDETR)(
    backbone=L(ResNet)(
        stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=50,
            stride_in_1x1=False,
            norm="FrozenBN",
        ),
        out_features=["res2", "res3", "res4", "res5"],
        freeze_at=1,
    ),
    in_features=["res5"],  # only use last level feature in Conditional-DETR
    in_channels=2048,
    position_embedding=L(PositionEmbeddingSine)(
        num_pos_feats=128,
        temperature=10000,
        normalize=True,
    ),
    transformer=L(GroupDetrTransformer)(
        encoder=L(GroupDetrTransformerEncoder)(
            embed_dim=256,
            num_heads=8,
            attn_dropout=0.1,
            feedforward_dim=2048,
            ffn_dropout=0.1,
            activation=L(nn.ReLU)(),
            num_layers=6,
            post_norm=False,
        ),
        decoder=L(GroupDetrTransformerDecoder)(
            embed_dim=256,
            num_heads=8,
            attn_dropout=0.1,
            feedforward_dim=2048,
            ffn_dropout=0.1,
            activation=L(nn.ReLU)(),
            num_layers=6,
            group_nums="${...group_nums}",
            post_norm=True,
            return_intermediate=True,
        ),
    ),
    embed_dim=256,
    num_classes=80,
    num_queries=300,
    criterion=L(GroupSetCriterion)(
        num_classes=80,
        matcher=L(GroupHungarianMatcher)(
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
        ),
        weight_dict={
            "loss_class": 2.0,
            "loss_bbox": 5.0,
            "loss_giou": 2.0,
        },
        group_nums="${..group_nums}",
        alpha=0.25,
        gamma=2.0,
    ),
    aux_loss=True,
    group_nums=11,
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
    select_box_nums_for_evaluation=300,
    device="cuda",
)

# set aux loss weight dict
if model.aux_loss:
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {}
    for i in range(model.transformer.decoder.num_layers - 1):
        aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict
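Two details in this config are worth spelling out. First, `group_nums` is declared once on the top-level model, and the decoder and criterion pick it up through omegaconf-style relative interpolation (`"${...group_nums}"` and `"${..group_nums}"` resolve upward from the referencing node). Second, the aux-loss block expands the weight dict with one suffixed copy of each loss per intermediate decoder layer. A standalone snippet illustrating the resulting keys (hypothetical, for illustration only):

```python
# With num_layers=6, each base loss key gains suffixed copies _0 .. _4 for the
# five intermediate decoder layers; the unsuffixed key covers the final layer.
weight_dict = {"loss_class": 2.0, "loss_bbox": 5.0, "loss_giou": 2.0}
num_layers = 6
aux = {f"{k}_{i}": v for i in range(num_layers - 1) for k, v in weight_dict.items()}
weight_dict.update(aux)
print(sorted(weight_dict))  # 18 keys: loss_bbox, loss_bbox_0, ..., loss_giou_4
```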
9 changes: 9 additions & 0 deletions projects/group_detr/modeling/__init__.py
@@ -0,0 +1,9 @@
from .group_detr import GroupDETR
from .group_detr_transformer import (
    GroupDetrTransformerEncoder,
    GroupDetrTransformerDecoder,
    GroupDetrTransformer,
)
from .attention import GroupConditionalSelfAttention
from .group_criterion import GroupSetCriterion
from .group_matcher import GroupHungarianMatcher
184 changes: 184 additions & 0 deletions projects/group_detr/modeling/attention.py
@@ -0,0 +1,184 @@
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import torch
import torch.nn as nn


class GroupConditionalSelfAttention(nn.Module):
    """Conditional self-attention module used in Group-DETR, based on
    `Conditional DETR for Fast Training Convergence
    <https://arxiv.org/pdf/2108.06152.pdf>`_.

    Args:
        embed_dim (int): The embedding dimension for attention.
        num_heads (int): The number of attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `MultiheadAttention`.
            Default: 0.0.
        group_nums (int): The number of decoder query groups used during
            training. Default: 11.
        batch_first (bool): If `True`, the input and output tensors are
            provided as `(bs, n, embed_dim)`; otherwise as `(n, bs, embed_dim)`.
            Default: False.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        attn_drop=0.0,
        proj_drop=0.0,
        group_nums=11,
        batch_first=False,
        **kwargs,
    ):
        super(GroupConditionalSelfAttention, self).__init__()
        self.query_content_proj = nn.Linear(embed_dim, embed_dim)
        self.query_pos_proj = nn.Linear(embed_dim, embed_dim)
        self.key_content_proj = nn.Linear(embed_dim, embed_dim)
        self.key_pos_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.group_nums = group_nums
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5
        self.batch_first = batch_first

    def forward(
        self,
        query,
        key=None,
        value=None,
        identity=None,
        query_pos=None,
        key_pos=None,
        attn_mask=None,
        key_padding_mask=None,
        **kwargs,
    ):
        """Forward function for `GroupConditionalSelfAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (torch.Tensor): Query embeddings with shape
                `(num_query, bs, embed_dim)` if self.batch_first is False,
                else `(bs, num_query, embed_dim)`.
            key (torch.Tensor): Key embeddings with shape
                `(num_key, bs, embed_dim)` if self.batch_first is False,
                else `(bs, num_key, embed_dim)`.
            value (torch.Tensor): Value embeddings with the same shape as `key`.
                Same in `torch.nn.MultiheadAttention.forward`. Default: None.
                If None, the `key` will be used.
            identity (torch.Tensor): The tensor, with the same shape as `query`,
                which will be used for identity addition. Default: None.
                If None, `query` will be used.
            query_pos (torch.Tensor): The position embedding for query, with the
                same shape as `query`. Default: None.
            key_pos (torch.Tensor): The position embedding for key. Default: None.
                If None, and `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`.
            attn_mask (torch.Tensor): ByteTensor mask with shape `(num_query, num_key)`.
                Same as `torch.nn.MultiheadAttention.forward`. Default: None.
            key_padding_mask (torch.Tensor): ByteTensor with shape `(bs, num_key)` which
                indicates which elements within `key` to be ignored in attention.
                Default: None.
        """
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(
                        f"position encoding of key is missing in {self.__class__.__name__}."
                    )

        assert (
            query_pos is not None and key_pos is not None
        ), "query_pos and key_pos must be passed into ConditionalAttention Module"

        # transpose (b n c) to (n b c) for attention calculation
        if self.batch_first:
            query = query.transpose(0, 1)  # (n b c)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
            query_pos = query_pos.transpose(0, 1)
            key_pos = key_pos.transpose(0, 1)
            identity = identity.transpose(0, 1)

        # query/key/value content and position embedding projection
        query_content = self.query_content_proj(query)
        query_pos = self.query_pos_proj(query_pos)
        key_content = self.key_content_proj(key)
        key_pos = self.key_pos_proj(key_pos)
        value = self.value_proj(value)

        # attention calculation
        N, B, C = query_content.shape
        q = query_content + query_pos
        k = key_content + key_pos
        v = value

        # hack in attention layer to implement group-detr: during training, split
        # the N queries into `group_nums` groups along the sequence dimension and
        # fold the groups into the batch dimension, so that self-attention is
        # computed independently within each group
        if self.training:
            q = torch.cat(q.split(N // self.group_nums, dim=0), dim=1)
            k = torch.cat(k.split(N // self.group_nums, dim=0), dim=1)
            v = torch.cat(v.split(N // self.group_nums, dim=0), dim=1)
            # update the effective sequence/batch sizes: (N, B) -> (N/g, B*g)
            N, B = N // self.group_nums, B * self.group_nums

        q = q.reshape(N, B, self.num_heads, C // self.num_heads).permute(
            1, 2, 0, 3
        )  # (B, num_heads, N, head_dim)
        k = k.reshape(N, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)
        v = v.reshape(N, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # add attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn.masked_fill_(attn_mask, float("-inf"))
            else:
                attn += attn_mask
        if key_padding_mask is not None:
            attn = attn.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.out_proj(out)

        if not self.batch_first:
            out = out.transpose(0, 1)

        # undo the group folding: move the groups from the batch dimension back
        # to the sequence dimension, restoring the original (N, B, C) layout
        if self.training:
            out = torch.cat(out.split(B // self.group_nums, dim=1), dim=0)

        return identity + self.proj_drop(out)
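As a quick shape check, a hypothetical smoke test (not part of this commit; it assumes the module is driven in its default sequence-first layout with `num_queries` divisible by `group_nums`):

```python
import torch
from projects.group_detr.modeling import GroupConditionalSelfAttention

# With group_nums=2, the 8 queries form two groups of 4 that attend only
# within their own group during training; the (N, B, C) layout is preserved.
attn = GroupConditionalSelfAttention(embed_dim=256, num_heads=8, group_nums=2)
queries = torch.randn(8, 2, 256)    # (num_queries, bs, embed_dim)
query_pos = torch.randn(8, 2, 256)  # positional embeddings, same shape

attn.train()                        # training: queries are split into groups
out = attn(queries, query_pos=query_pos)
assert out.shape == (8, 2, 256)     # group folding is undone before returning

attn.eval()                         # inference: no group split happens
out = attn(queries, query_pos=query_pos)
assert out.shape == (8, 2, 256)
```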
