Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for convtranspose2d #5239

Merged
merged 26 commits into from
Jun 21, 2021
Merged

Conversation

BBuf
Copy link
Contributor

@BBuf BBuf commented Jun 18, 2021

实现group convtranspose2d功能。

self.weight = flow.nn.Parameter(
flow.Tensor(in_channels, out_channels // groups, *kernel_size)
flow.Tensor(out_channels // groups, in_channels, *kernel_size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flow.Tensor(in_channels, out_channels // groups, *kernel_size) 应该是这样

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修正

out_list = []
for i in range(len(in_split_list)):
out_list.append(
self._op(in_split_list[i], self.weight[:, i : i + 1, :, :])[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里 weight 取错了?应该是 self.weight[i * groups : (i+1) * groups, :, :, :])[0]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方应该是是没有取错的,上面权重的shape已经是flow.Tensor(in_channels, out_channels // groups, *kernel_size),相当于这个通道已经除了groups了。

@oneflow-ci-bot oneflow-ci-bot removed their request for review June 18, 2021 12:00
out_list = []
for i in range(len(in_split_list)):
out_list.append(
self._op(in_split_list[i], self.weight[i : (i + 1), :, :, :])[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里还是有问题,你测试用例里面 输入通道没有大于2的,测不出来问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改。

@BBuf BBuf requested a review from oneflow-ci-bot June 20, 2021 10:23
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot June 21, 2021 01:38
@oneflow-ci-bot oneflow-ci-bot merged commit 903efa1 into master Jun 21, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the add_test_for_convtranspose2d branch June 21, 2021 02:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants