Add GroupNorm #5175
Conversation
if __name__ == "__main__":
    import doctest

    doctest.testmod()
Add a newline at the end of the file.
from oneflow.python.nn.module import Module


@oneflow_export("nn.GroupNorm")
You need to add a corresponding test_xxx file under the oneflow/python/test/modules directory to test this.
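For concreteness, a hypothetical skeleton of such a file (oneflow/python/test/modules/test_groupnorm.py) is sketched below; the decorator, base class, and device handling are assumptions modeled on the other module tests referenced in this PR, not a confirmed API.

```python
# Hypothetical skeleton for oneflow/python/test/modules/test_groupnorm.py.
# Names like flow.unittest.skip_unless_1n1d and the .to(...) call are
# assumptions modeled on the other module tests of this era.
import unittest
from collections import OrderedDict

import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList


def _test_groupnorm(test_case, device):
    m = flow.nn.GroupNorm(2, 4).to(device=flow.device(device))
    x = flow.Tensor(np.random.randn(1, 4, 2, 2), device=flow.device(device))
    y = m(x)
    # GroupNorm must preserve the input shape.
    test_case.assertTrue(y.shape == x.shape)


@flow.unittest.skip_unless_1n1d()
class TestGroupNorm(flow.unittest.TestCase):
    def test_groupnorm(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [_test_groupnorm]
        arg_dict["device"] = ["cpu", "cuda"]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])


if __name__ == "__main__":
    unittest.main()
```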
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html

Applies Group Normalization over a mini-batch of inputs as described in
Please build the docs and attach a screenshot of the rendered docstring.
https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html

Applies Group Normalization over a mini-batch of inputs as described in
the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
This hyperlink form may not be well supported by our docs build. Please keep an eye on the built output and make sure the build produces no warnings (otherwise it won't pass CI as things stand).
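(If the embedded-target form does break the build, one plain fallback, assuming the build at least passes raw URLs through, is to drop the target syntax entirely and write e.g. ``the paper Group Normalization: https://arxiv.org/abs/1803.08494``.)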
assert (input.shape[1] == self.num_channels), "The channels of input tensor must equal num_channels"
origin_shape = input.shape
reshape_to_1d = flow.experimental.reshape(input, shape=[origin_shape[0], self.num_groups, -1])
(mean, variance) = flow.experimental.nn.moments(reshape_to_1d, [2], keepdims=True)
PyTorch has no nn.moments, and OneFlow has not ported flow.experimental.nn.moments either. Instead of nn.moments, compute the two statistics separately with the already-ported flow.experimental.mean and the not-yet-ported flow.experimental.var (its legacy counterpart flow.reduce_variance already exists). So in this PR I suggest also aligning and porting flow.experimental.var (this is exactly why I said the real requirements surface during porting).
PS: a problem like flow.experimental.nn.moments not existing would show up the moment you tried it yourself, so while porting, please run every test you can on your own. You will catch such issues yourself and save the time cost of back-and-forth communication.
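A minimal sketch of that suggestion, assuming flow.experimental.mean takes torch-style dim/keepdim arguments (the exact signature here is an assumption):

```python
# Replace the nonexistent flow.experimental.nn.moments with two reductions:
# mean = E[x], variance = E[(x - E[x])^2], both over the flattened group dim.
mean = flow.experimental.mean(reshape_to_1d, dim=2, keepdim=True)
centered = reshape_to_1d - mean
variance = flow.experimental.mean(centered * centered, dim=2, keepdim=True)
```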
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-5, 1e-5))


def _test_groupnorm_v2(test_case, device):
Could you give this a more meaningful name? v2 carries too little information; something like smaller_shape would say more.
        flow.nn.init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        assert len(input.shape) >= 3
Please add a message after the assert.
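For example (the exact wording of the message is just a suggestion):

```python
assert len(input.shape) >= 3, "The dimensions of input tensor must be greater than 2"
```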
Added.
and zeros (for biases). Default: ``True``.

Shape:
    - Input: :math:`(N, C, *)` where :math:`C=\text{num_channels}`
This formula doesn't seem to render correctly.
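(One plausible culprit, though this is a guess: the unescaped underscore inside the math role. LaTeX generally wants ``C=\text{num\_channels}``.)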
This has been fixed; I'll update the screenshot.
>>> # Put all 6 channels into a single group (equivalent with LayerNorm)
>>> m = flow.nn.GroupNorm(1, 6)
>>> # Activating the module
>>> output = m(input)
Take a look at the output shape in each of these cases.
The output shape is the same in every case.
OK then.
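A quick way to verify that claim (a sketch; the 20x6x10x10 input mirrors the docstring example, and the enable_eager_execution preamble is assumed to be required for the experimental API in this era):

```python
import numpy as np
import oneflow.experimental as flow

flow.enable_eager_execution()  # assumed preamble for the eager API

# Normalization rescales values but never reshapes them, so any valid
# grouping of the 6 channels yields an output with the input's shape.
x = flow.Tensor(np.random.randn(20, 6, 10, 10))
for groups in (3, 6, 1):
    m = flow.nn.GroupNorm(groups, 6)
    assert m(x).shape == x.shape
```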
from test_util import GenArgList


input_arr = np.array(
Is this redundant?
Yes, I'll remove it.
    ],
    dtype=np.float32,
)
output = np.array(
Maybe the forward output could be simulated with numpy?
The other tests, instancenorm and batchnorm, are written this way too, with the expected values inlined directly.
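For reference, a numpy simulation of the forward pass along the lines suggested above might look like this (the eps default and the affine handling are assumptions mirroring the PyTorch semantics):

```python
import numpy as np


def groupnorm_ref(x, num_groups, gamma=None, beta=None, eps=1e-5):
    # Normalize each (sample, group) slice to zero mean / unit variance,
    # then apply the optional per-channel affine transform.
    n, c = x.shape[0], x.shape[1]
    g = x.reshape(n, num_groups, -1)
    mean = g.mean(axis=2, keepdims=True)
    var = g.var(axis=2, keepdims=True)
    y = ((g - mean) / np.sqrt(var + eps)).reshape(x.shape)
    if gamma is not None:
        shape = (1, c) + (1,) * (x.ndim - 2)  # broadcast over spatial dims
        y = y * gamma.reshape(shape) + beta.reshape(shape)
    return y
```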
Add GroupNorm Module
![GroupNorm](https://user-images.githubusercontent.com/84563719/122495704-192de800-d01d-11eb-9d54-198c71236baa.png)
GroupNorm Doctest
![test_groupnorm](https://user-images.githubusercontent.com/84563719/122150588-96782200-ce90-11eb-829e-4652b70cb4c6.png)