Commit:

* support group detr
* fix matcher bug
* refine readme
* refine license

Co-authored-by: ntianhe ren <rentianhe@dgx061.scc.idea>

Showing 12 changed files with 1,229 additions and 1 deletion.
@@ -0,0 +1,38 @@

## Group DETR: Fast DETR Training with Group-Wise One-to-Many Assignment

Qiang Chen, Xiaokang Chen, Jian Wang, Haocheng Feng, Junyu Han, Errui Ding, Gang Zeng, Jingdong Wang

[[`arXiv`](https://arxiv.org/abs/2207.13085)] [[`BibTeX`](#citing-group-detr)]

<div align="center">
  <img src="./assets/group_detr_arch.png"/>
</div><br/>

**Note**: This project implements `Conditional DETR + Group DETR`, i.e., Group DETR built on top of the Conditional DETR baseline.

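The core idea behind the faster training is group-wise one-to-many assignment: the decoder queries are split into `K` groups, bipartite matching is run independently inside each group, so every ground-truth object collects `K` positive queries during training, and only a single group is kept at inference. The snippet below is a minimal standalone sketch of that assignment scheme; it is not code from this commit, and `group_wise_assignment` is a hypothetical helper.

```python
# Minimal sketch of group-wise one-to-many assignment (illustrative only).
import torch
from scipy.optimize import linear_sum_assignment


def group_wise_assignment(cost: torch.Tensor, group_nums: int):
    """Run Hungarian matching independently inside each query group.

    cost: (num_queries_total, num_gt) matching cost for one image, where
    num_queries_total = group_nums * queries_per_group.
    Returns a list of (query_indices, gt_indices) pairs, one per group.
    """
    queries_per_group = cost.shape[0] // group_nums
    matches = []
    for g, group_cost in enumerate(cost.split(queries_per_group, dim=0)):
        src, tgt = linear_sum_assignment(group_cost.cpu().numpy())
        # offset query indices back into the full (ungrouped) query range
        matches.append((torch.as_tensor(src) + g * queries_per_group, torch.as_tensor(tgt)))
    return matches


# toy example: 2 groups of 4 queries matched against 3 ground-truth boxes
print(group_wise_assignment(torch.rand(8, 3), group_nums=2))
```
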
## Training
All configs can be trained with:
```bash
cd detrex
python tools/train_net.py --config-file projects/group_detr/configs/path/to/config.py --num-gpus 8
```
By default, we use 8 GPUs with a total batch size of 16 for training.

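Since `tools/train_net.py` consumes these LazyConfig files, individual config fields can also be overridden from the command line in the same `key=value` style used to set `train.init_checkpoint` in the evaluation command below. For example (the output directory here is purely illustrative):

```bash
cd detrex
python tools/train_net.py --config-file projects/group_detr/configs/path/to/config.py --num-gpus 8 \
    train.output_dir=./output/my_group_detr_run \
    train.max_iter=375000
```
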
## Evaluation
Model evaluation can be done as follows:
```bash
cd detrex
python tools/train_net.py --config-file projects/group_detr/configs/path/to/config.py --eval-only train.init_checkpoint=/path/to/model_checkpoint
```

## Citing Group-DETR
If you find our work helpful for your research, please consider citing the following BibTeX entry.

```BibTeX
@article{chen2022group,
  title={Group DETR: Fast DETR Training with Group-Wise One-to-Many Assignment},
  author={Chen, Qiang and Chen, Xiaokang and Wang, Jian and Feng, Haocheng and Han, Junyu and Ding, Errui and Zeng, Gang and Wang, Jingdong},
  journal={arXiv preprint arXiv:2207.13085},
  year={2022}
}
```
@@ -0,0 +1,22 @@
```python
from detrex.config import get_config
from .models.group_detr_r50 import model

dataloader = get_config("common/data/coco_detr.py").dataloader
optimizer = get_config("common/optim.py").AdamW
lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_50ep
train = get_config("common/train.py").train

# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.output_dir = "./output/group_detr_r50_50ep"
train.max_iter = 375000
train.clip_grad.enabled = True
train.clip_grad.params.max_norm = 0.1
train.clip_grad.params.norm_type = 2

# modify optimizer config
optimizer.weight_decay = 1e-4
optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1

# modify dataloader config
dataloader.train.num_workers = 16
```
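This is a standard detectron2 LazyConfig, so it can also be loaded and instantiated outside of `tools/train_net.py`. A minimal sketch follows; the config filename used here is an assumption, since the diff does not show the file's path.

```python
# Minimal sketch: loading the lazy config above outside the training script.
# The path below is assumed; the diff does not show the config's filename.
from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("projects/group_detr/configs/group_detr_r50_50ep.py")
model = instantiate(cfg.model)  # builds the GroupDETR model described by the config
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```
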
@@ -0,0 +1,94 @@
```python
import torch.nn as nn

from detrex.layers import PositionEmbeddingSine
from detrex.modeling.backbone import ResNet, BasicStem

from detectron2.config import LazyCall as L

from projects.group_detr.modeling import (
    GroupDETR,
    GroupDetrTransformer,
    GroupDetrTransformerDecoder,
    GroupDetrTransformerEncoder,
    GroupHungarianMatcher,
    GroupSetCriterion,
)


model = L(GroupDETR)(
    backbone=L(ResNet)(
        stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=50,
            stride_in_1x1=False,
            norm="FrozenBN",
        ),
        out_features=["res2", "res3", "res4", "res5"],
        freeze_at=1,
    ),
    in_features=["res5"],  # only use last level feature in Conditional-DETR
    in_channels=2048,
    position_embedding=L(PositionEmbeddingSine)(
        num_pos_feats=128,
        temperature=10000,
        normalize=True,
    ),
    transformer=L(GroupDetrTransformer)(
        encoder=L(GroupDetrTransformerEncoder)(
            embed_dim=256,
            num_heads=8,
            attn_dropout=0.1,
            feedforward_dim=2048,
            ffn_dropout=0.1,
            activation=L(nn.ReLU)(),
            num_layers=6,
            post_norm=False,
        ),
        decoder=L(GroupDetrTransformerDecoder)(
            embed_dim=256,
            num_heads=8,
            attn_dropout=0.1,
            feedforward_dim=2048,
            ffn_dropout=0.1,
            activation=L(nn.ReLU)(),
            num_layers=6,
            group_nums="${...group_nums}",
            post_norm=True,
            return_intermediate=True,
        ),
    ),
    embed_dim=256,
    num_classes=80,
    num_queries=300,
    criterion=L(GroupSetCriterion)(
        num_classes=80,
        matcher=L(GroupHungarianMatcher)(
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
        ),
        weight_dict={
            "loss_class": 2.0,
            "loss_bbox": 5.0,
            "loss_giou": 2.0,
        },
        group_nums="${..group_nums}",
        alpha=0.25,
        gamma=2.0,
    ),
    aux_loss=True,
    group_nums=11,
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
    select_box_nums_for_evaluation=300,
    device="cuda",
)

# set aux loss weight dict
if model.aux_loss:
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {}
    for i in range(model.transformer.decoder.num_layers - 1):
        aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict
```
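With the 6-layer decoder above, this loop adds weights for the 5 intermediate decoder layers on top of the 3 base losses. The standalone sketch below simply mirrors the loop and shows how many weighted loss terms result:

```python
# Standalone sketch mirroring the aux-loss expansion above for a 6-layer decoder.
weight_dict = {"loss_class": 2.0, "loss_bbox": 5.0, "loss_giou": 2.0}
aux_weight_dict = {}
for i in range(6 - 1):  # one extra set of weights per intermediate decoder layer
    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
print(len(weight_dict))  # 18 weighted loss terms: 3 final-layer + 15 auxiliary
```
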
@@ -0,0 +1,9 @@
```python
from .group_detr import GroupDETR
from .group_detr_transformer import (
    GroupDetrTransformerEncoder,
    GroupDetrTransformerDecoder,
    GroupDetrTransformer,
)
from .attention import GroupConditionalSelfAttention
from .group_criterion import GroupSetCriterion
from .group_matcher import GroupHungarianMatcher
```
@@ -0,0 +1,184 @@
```python
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import torch
import torch.nn as nn


class GroupConditionalSelfAttention(nn.Module):
    """Conditional self-attention module used in Group-DETR, built on the
    conditional attention of `Conditional DETR for Fast Training Convergence
    <https://arxiv.org/pdf/2108.06152.pdf>`_.

    Args:
        embed_dim (int): The embedding dimension for attention.
        num_heads (int): The number of attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `MultiheadAttention`.
            Default: 0.0.
        group_nums (int): The number of query groups used in Group-DETR.
            Default: 11.
        batch_first (bool): If `True`, the input and output tensors are provided
            as `(bs, n, embed_dim)`; otherwise as `(n, bs, embed_dim)`.
            Default: False.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        attn_drop=0.0,
        proj_drop=0.0,
        group_nums=11,
        batch_first=False,
        **kwargs,
    ):
        super(GroupConditionalSelfAttention, self).__init__()
        self.query_content_proj = nn.Linear(embed_dim, embed_dim)
        self.query_pos_proj = nn.Linear(embed_dim, embed_dim)
        self.key_content_proj = nn.Linear(embed_dim, embed_dim)
        self.key_pos_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.group_nums = group_nums
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5
        self.batch_first = batch_first

    def forward(
        self,
        query,
        key=None,
        value=None,
        identity=None,
        query_pos=None,
        key_pos=None,
        attn_mask=None,
        key_padding_mask=None,
        **kwargs,
    ):
        """Forward function for `GroupConditionalSelfAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (torch.Tensor): Query embeddings with shape
                `(num_query, bs, embed_dim)` if self.batch_first is False,
                else `(bs, num_query, embed_dim)`.
            key (torch.Tensor): Key embeddings with shape
                `(num_key, bs, embed_dim)` if self.batch_first is False,
                else `(bs, num_key, embed_dim)`.
            value (torch.Tensor): Value embeddings with the same shape as `key`.
                Same as in `torch.nn.MultiheadAttention.forward`. Default: None.
                If None, the `key` will be used.
            identity (torch.Tensor): The tensor, with the same shape as `query`,
                which will be used for identity addition. Default: None.
                If None, `query` will be used.
            query_pos (torch.Tensor): The position embedding for query, with the
                same shape as `query`. Default: None.
            key_pos (torch.Tensor): The position embedding for key. Default: None.
                If None, and `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`.
            attn_mask (torch.Tensor): ByteTensor mask with shape `(num_query, num_key)`.
                Same as `torch.nn.MultiheadAttention.forward`. Default: None.
            key_padding_mask (torch.Tensor): ByteTensor with shape `(bs, num_key)` which
                indicates which elements within `key` to be ignored in attention.
                Default: None.
        """
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(
                        f"position encoding of key is missing in {self.__class__.__name__}."
                    )

        assert (
            query_pos is not None and key_pos is not None
        ), "query_pos and key_pos must be passed into GroupConditionalSelfAttention"

        # transpose (b n c) to (n b c) for attention calculation
        if self.batch_first:
            query = query.transpose(0, 1)  # (n, b, c)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
            query_pos = query_pos.transpose(0, 1)
            key_pos = key_pos.transpose(0, 1)
            identity = identity.transpose(0, 1)

        # query/key/value content and position embedding projection
        query_content = self.query_content_proj(query)
        query_pos = self.query_pos_proj(query_pos)
        key_content = self.key_content_proj(key)
        key_pos = self.key_pos_proj(key_pos)
        value = self.value_proj(value)

        # attention calculation
        N, B, C = query_content.shape
        q = query_content + query_pos
        k = key_content + key_pos
        v = value

        # hack in attention layer to implement group-detr: during training the
        # query groups are folded into the batch dimension so that self-attention
        # is computed independently within each group
        if self.training:
            q = torch.cat(q.split(N // self.group_nums, dim=0), dim=1)
            k = torch.cat(k.split(N // self.group_nums, dim=0), dim=1)
            v = torch.cat(v.split(N // self.group_nums, dim=0), dim=1)

        # after the (optional) regrouping, q/k/v have shape (N_g, B_g, C), where
        # N_g = N // group_nums and B_g = B * group_nums during training,
        # and N_g = N, B_g = B otherwise
        N_g, B_g = q.shape[:2]

        q = q.reshape(N_g, B_g, self.num_heads, C // self.num_heads).permute(
            1, 2, 0, 3
        )  # (B_g, num_heads, N_g, head_dim)
        k = k.reshape(N_g, B_g, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)
        v = v.reshape(N_g, B_g, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # add attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn.masked_fill_(attn_mask, float("-inf"))
            else:
                attn += attn_mask
        if key_padding_mask is not None:
            attn = attn.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B_g, N_g, C)
        out = self.out_proj(out)

        if not self.batch_first:
            out = out.transpose(0, 1)

        if self.training:
            # undo the grouping: split the folded batch dimension back into the
            # original batch size B and stack the groups along the query dimension
            out = torch.cat(out.split(B, dim=1), dim=0)

        return identity + self.proj_drop(out)
```
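As a quick sanity check (not part of this commit), the attention module above can be exercised with dummy tensors. In training mode the number of queries must be divisible by `group_nums`, and the output keeps the `(num_query, bs, embed_dim)` layout of the input:

```python
# Standalone sketch: exercising GroupConditionalSelfAttention with dummy inputs.
# Assumes detrex is installed so that projects.group_detr.modeling is importable.
import torch
from projects.group_detr.modeling import GroupConditionalSelfAttention

attn = GroupConditionalSelfAttention(embed_dim=256, num_heads=8, group_nums=11)
attn.train()  # training mode enables the group-wise attention path

num_query, bs = 33, 2  # num_query must be divisible by group_nums during training
query = torch.randn(num_query, bs, 256)      # (n, b, c) layout since batch_first=False
query_pos = torch.randn(num_query, bs, 256)  # positional embeddings for the queries

out = attn(query, query_pos=query_pos)
print(out.shape)  # torch.Size([33, 2, 256])
```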