Why 'ws 1 for stand attention' in your GroupAttention code? #12

Closed
kejie-cn opened this issue May 31, 2021 · 5 comments

Comments

@kejie-cn

I find that in your implementation of GroupAttention in gvt.py, you comment that 'ws 1 for stand attention'.

class GroupAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ws=1, sr_ratio=1.0):
        """
        ws 1 for stand attention
        """
        super(GroupAttention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

However, I think ws means the window size; if ws=1, then self-attention is only performed within a 1x1 window, which is not standard self-attention.
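
For reference, here is a minimal single-head sketch of how I read window (locally-grouped) attention, assuming ws divides H and W evenly and omitting the qkv projections; it is not the actual gvt.py code. With ws=1 every window holds a single token, so each token can only attend to itself:

import torch

def window_attention(x, H, W, ws):
    # x: (B, N, C) with N == H * W; attention is restricted to each ws x ws window
    B, N, C = x.shape
    h_g, w_g = H // ws, W // ws                              # windows per side
    x = x.view(B, h_g, ws, w_g, ws, C)
    x = x.transpose(2, 3).reshape(B, h_g * w_g, ws * ws, C)  # (B, num_windows, tokens_per_window, C)
    attn = (x @ x.transpose(-2, -1)) * (C ** -0.5)           # scores computed within each window only
    attn = attn.softmax(dim=-1)
    out = attn @ x                                           # (B, num_windows, ws * ws, C)
    out = out.reshape(B, h_g, w_g, ws, ws, C).transpose(2, 3).reshape(B, N, C)
    return out

# with ws=1 each window contains exactly one token, so attention is the identity
x = torch.randn(2, 49, 64)
print(torch.allclose(window_attention(x, 7, 7, ws=1), x))    # True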

@cxxgtxy
Collaborator

cxxgtxy commented May 31, 2021

This is an implementation choice.
We use ws=1 to stand for standard attention. The resolution of the feature map in the last stage is 7x7, and we perform standard self-attention there because it is cheap at that stage.
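
In other words, ws acts as a sentinel when each block's attention module is built: ws == 1 selects plain global self-attention rather than a literal 1x1 window. A rough sketch of that intent (not a copy of gvt.py; nn.MultiheadAttention only stands in for the repo's own global attention class):

import torch.nn as nn
from gvt import GroupAttention   # assuming gvt.py from this repo is on the path

def build_attention(dim, num_heads, ws, sr_ratio):
    if ws == 1:
        # ws == 1 is a flag for standard (global) self-attention;
        # on a 7x7 last-stage map this is cheap, so no windowing is needed
        return nn.MultiheadAttention(dim, num_heads, batch_first=True)
    # ws > 1 means locally-grouped attention inside ws x ws windows,
    # matching the GroupAttention signature quoted above
    return GroupAttention(dim, num_heads=num_heads, ws=ws, sr_ratio=sr_ratio)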

@kejie-cn
Author


But standard self-attention should use a window size equal to the feature size (ws = 7 in the last stage).

@cxxgtxy
Collaborator

cxxgtxy commented May 31, 2021

In the detailed implementation, ws=7 does not work in the last stage. Please check the code.

@xiaohu2015

xiaohu2015 commented May 31, 2021

@cxxgtxy In your paper, the last stage uses only GSA. For 224 classification, the last-stage feature size is 7x7, so ws = 7 (LSA) and ws = 1 (GSA) are equivalent. But for detection or segmentation the last-stage feature size may not be 7x7, so ws = 7 (LSA) and ws = 1 (GSA) are not equivalent. Does this mean you use LSA and GSA at the same time in the last stage?
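
To make the comparison concrete (hypothetical feature sizes, and ignoring the key/value sub-sampling that GSA uses):

def lsa_scope(H, W, ws):
    # tokens each query can attend to under LSA with ws x ws windows
    return min(ws, H) * min(ws, W)

def gsa_scope(H, W):
    # tokens each query can attend to under global attention (no sub-sampling shown)
    return H * W

print(lsa_scope(7, 7, ws=7), gsa_scope(7, 7))        # 49 49   -> identical on a 7x7 map
print(lsa_scope(32, 32, ws=7), gsa_scope(32, 32))    # 49 1024 -> different on a larger detection map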

@cxxgtxy
Collaborator

cxxgtxy commented Jun 4, 2021


It's a good question. The feature map for the detection and segmentation tasks is indeed larger than 7x7. In the implementation, we use ws=1 (GSA) in the last stage, just as for classification.
Please see the code:
for k in range(len(depths)):
    _block = nn.ModuleList([block_cls(
        dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k],
        qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
        attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
        sr_ratio=sr_ratios[k], ws=1 if i % 2 == 1 else wss[k])
        for i in range(depths[k])])
    self.blocks.append(_block)
    cur += depths[k]
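
To illustrate the `ws=1 if i % 2 == 1 else wss[k]` pattern above (a small sketch with example values for depths and wss, not necessarily a released config): even-indexed blocks in each stage use locally-grouped attention with the stage window size, while odd-indexed blocks use ws=1, i.e. the global (sub-sampled) attention path.

# depths and wss below are example values, not necessarily a released config
depths = [2, 2, 10, 4]
wss = [7, 7, 7, 7]
for k in range(len(depths)):
    layout = ['GSA(ws=1)' if i % 2 == 1 else f'LSA(ws={wss[k]})' for i in range(depths[k])]
    print(f'stage {k}:', layout)
# even-indexed blocks -> locally-grouped attention (ws = wss[k]);
# odd-indexed blocks  -> ws = 1, i.e. the global (sub-sampled) attention path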

@cxxgtxy cxxgtxy closed this as completed Jun 7, 2021
littleSunlxy pushed a commit to littleSunlxy/Twins that referenced this issue Nov 4, 2021