
Fix upsample sbp infer bug and add global test #7884

Merged
merged 44 commits on Apr 11, 2022
Commits
e6e6973
fix reduce_sum scalar check bug
BBuf Mar 22, 2022
a0abdd5
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 23, 2022
00522df
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 23, 2022
9a40937
fix_unfold_tensor_sbp_and_add_global_test
clackhan Mar 23, 2022
8ff6ec6
refine
clackhan Mar 23, 2022
d38ba32
add_var_upsample_global_test
clackhan Mar 23, 2022
303c5f9
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Mar 24, 2022
cd5964d
revert unfold_tensor_op.cpp
clackhan Mar 24, 2022
c90a291
Merge branch 'master' into add_var_upsample_global_test
clackhan Mar 24, 2022
54f71f9
fix var bug about tail element handle
liufengwei0103 Mar 24, 2022
68e0e08
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 24, 2022
f154da2
fix reshape infer error for 0-dim size
clackhan Mar 24, 2022
f76f499
del code not belong to this branch
clackhan Mar 24, 2022
0b90f9b
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 28, 2022
da30b1a
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Mar 28, 2022
21741e6
fix bicubic interpolate op cuda kernel bug, test success
BBuf Mar 28, 2022
a5f4921
fix bug
BBuf Mar 29, 2022
468e775
Merge branch 'master' into add_var_upsample_global_test
clackhan Mar 29, 2022
81faf24
Merge branch 'fix_bicubic_interpolate_bug' of https://github.com/Onef…
clackhan Mar 29, 2022
fb812b7
Merge branch 'master' into add_var_upsample_global_test
BBuf Mar 29, 2022
f08840a
Merge branch 'fix_bicubic_interpolate_bug' into add_var_upsample_glob…
BBuf Mar 29, 2022
9f6551c
fix segement fault bug
BBuf Mar 29, 2022
2ab15d4
Merge branch 'add_var_upsample_global_test' of github.com:Oneflow-Inc…
BBuf Mar 29, 2022
84cf459
Merge branch 'add_var_upsample_global_test' of https://github.com/One…
clackhan Mar 29, 2022
1c7cfa7
Merge branch 'master' into add_var_upsample_global_test
clackhan Mar 29, 2022
b65823f
make of_format
clackhan Mar 29, 2022
c89fe8b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Mar 29, 2022
86db50d
Merge branch 'add_var_upsample_global_test' of https://github.com/One…
clackhan Mar 29, 2022
20ceb35
Merge branch 'master' into add_var_upsample_global_test
clackhan Apr 6, 2022
d587929
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 6, 2022
23e4199
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 6, 2022
d0023c7
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 6, 2022
bccad55
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 7, 2022
01890db
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 7, 2022
600a508
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 7, 2022
4c13987
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 7, 2022
593278a
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 7, 2022
ddd5a37
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 8, 2022
e6ccea4
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 8, 2022
81dee25
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 8, 2022
d35ce15
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 8, 2022
6db23bb
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 9, 2022
382dd5d
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 11, 2022
75e34a3
Merge branch 'master' into add_var_upsample_global_test
mergify[bot] Apr 11, 2022
2 changes: 1 addition & 1 deletion oneflow/user/kernels/upsample_bicubic2d_kernel.cpp
@@ -121,7 +121,7 @@ class UpsampleBicubic2dGradCPUKernel final : public user_op::OpKernel {
const int64_t out_width = dy_tensor->shape().At(3);

if (in_height == out_height && in_width == out_width) {
- memcpy(in_ptr, out_ptr, sizeof(T) * nbatch * channels * in_height * in_width);
+ memcpy(in_ptr, out_ptr, sizeof(T) * channels * in_height * in_width);
Contributor: `channels` here is already the product of the original channel count and the batch size, so the extra `nbatch` factor overran the buffer and caused the segmentation fault.

} else {
const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
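The off-by-a-factor can be checked with a quick element count. A minimal sketch with hypothetical sizes (the kernel's `channels` already folds the batch dimension in, per the review comment above):

```python
# Hypothetical shapes; `channels` models the value the kernel receives,
# which is already nbatch * orig_channels.
nbatch, orig_channels, in_height, in_width = 2, 3, 4, 4
channels = nbatch * orig_channels

buffer_elems = nbatch * orig_channels * in_height * in_width  # real buffer size

old_copy = nbatch * channels * in_height * in_width  # buggy: batch counted twice
new_copy = channels * in_height * in_width           # fixed: matches the buffer

assert new_copy == buffer_elems           # fixed memcpy stays in bounds
assert old_copy == nbatch * buffer_elems  # buggy memcpy overran by a factor of nbatch
```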
10 changes: 5 additions & 5 deletions oneflow/user/kernels/upsample_bicubic2d_kernel.cu
@@ -46,11 +46,11 @@ __global__ void UpsampleBicubic2dForward(const int64_t elem_cnt, const T* in_dpt
T* out = out_dptr;

const T real_x = GetAreaPixel(scale_width, output_x, align_corners, /*cubic=*/true);
- int64_t input_x = std::floor(1.0 * real_x);
+ int64_t input_x = floor(1.0 * real_x);
const T t_x = real_x - input_x;

const T real_y = GetAreaPixel(scale_height, output_y, align_corners, /*cubic=*/true);
- int64_t input_y = std::floor(1.0 * real_y);
+ int64_t input_y = floor(1.0 * real_y);
const T t_y = real_y - input_y;

for (int64_t c = 0; c < channels * nbatch; c++) {
@@ -92,11 +92,11 @@ __global__ void UpsampleBicubic2dBackward(const int64_t elem_cnt, const T* dy_dp
const T* out = dy_dptr;

T real_x = GetAreaPixel(scale_width, output_x, align_corners, true);
- int64_t input_x = std::floor(1.0 * real_x);
+ int64_t input_x = floor(1.0 * real_x);
T t_x = real_x - input_x;

T real_y = GetAreaPixel(scale_height, output_y, align_corners, true);
- int64_t input_y = std::floor(1.0 * real_y);
+ int64_t input_y = floor(1.0 * real_y);
T t_y = real_y - input_y;

T x_coeffs[4];
@@ -105,7 +105,7 @@ __global__ void UpsampleBicubic2dBackward(const int64_t elem_cnt, const T* dy_dp
get_cubic_upsample_coefficients<T>(x_coeffs, t_x);
get_cubic_upsample_coefficients<T>(y_coeffs, t_y);

- for (int64_t c = 0; c < channels; c++) {
+ for (int64_t c = 0; c < channels * nbatch; c++) {
Contributor: This caused the backward computation to take effect only for the first batch.

T out_value = out[output_y * out_width + output_x];

for (int64_t i = 0; i < 4; i++) {
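The same flattening explains the loop-bound fix: the gradient data is laid out as `nbatch * channels` channel planes, and the old bound visited only the first batch's planes. A minimal sketch with hypothetical sizes:

```python
# Hypothetical sizes; planes are indexed 0 .. nbatch*channels-1 in memory.
nbatch, channels = 2, 3

old_visited = list(range(channels))            # for (c = 0; c < channels; c++)
new_visited = list(range(channels * nbatch))   # fixed bound covers every batch

assert old_visited == [0, 1, 2]                # only batch 0's planes got gradients
assert new_visited == [0, 1, 2, 3, 4, 5]       # all planes of both batches
```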
5 changes: 3 additions & 2 deletions oneflow/user/kernels/upsample_kernel.h
@@ -128,12 +128,13 @@ OF_DEVICE_FUNC T upsample_get_value_bounded(const T* data, const int64_t width,

template<typename T>
OF_DEVICE_FUNC T cubic_convolution1(const T x, const T A) {
- return ((A + 2.0) * x - (A + 3.0)) * x * x + 1.0;
+ return ((A + static_cast<T>(2.0)) * x - (A + static_cast<T>(3.0))) * x * x + static_cast<T>(1.0);
}

template<typename T>
OF_DEVICE_FUNC T cubic_convolution2(const T x, const T A) {
- return ((A * x - 5.0 * A) * x + 8.0 * A) * x - 4.0 * A;
+ return ((A * x - static_cast<T>(5.0) * A) * x + static_cast<T>(8.0) * A) * x
+        - static_cast<T>(4.0) * A;
}

template<typename T>
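The `static_cast<T>` changes above keep the arithmetic in `T` rather than promoting to `double`. The two functions are the branches of the Keys cubic convolution kernel; a quick sanity check (assuming `A = -0.75`, the value commonly used for bicubic upsampling) is that the four neighbor weights sum to 1 for any fractional offset:

```python
A = -0.75  # bicubic coefficient; assumed value, mirroring common upsample kernels

def cubic_convolution1(x, a):
    # |x| <= 1 branch, mirroring cubic_convolution1 in upsample_kernel.h
    return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0

def cubic_convolution2(x, a):
    # 1 < |x| < 2 branch, mirroring cubic_convolution2
    return ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a

def cubic_coefficients(t):
    # weights of the four neighbors at fractional offset t in [0, 1)
    return [
        cubic_convolution2(t + 1.0, A),
        cubic_convolution1(t, A),
        cubic_convolution1(1.0 - t, A),
        cubic_convolution2(2.0 - t, A),
    ]

# partition of unity: the weights sum to 1 regardless of t
for t in (0.0, 0.25, 0.5, 0.9):
    assert abs(sum(cubic_coefficients(t)) - 1.0) < 1e-12
```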
48 changes: 40 additions & 8 deletions oneflow/user/ops/upsample_op.cpp
@@ -220,7 +220,11 @@ namespace oneflow {
}

/*static*/ Maybe<void> UpsampleLinear1DGradOp::GetSbp(user_op::SbpContext* ctx) {
- ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+ ctx->NewBuilder()
+     .Split(user_op::OpArg("dy", 0), 0)
+     .Split(user_op::OpArg("x", 0), 0)
+     .Split(user_op::OpArg("dx", 0), 0)
+     .Build();
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleLinear1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
@@ -241,7 +245,11 @@ namespace oneflow {
}

/*static*/ Maybe<void> UpsampleNearest1DGradOp::GetSbp(user_op::SbpContext* ctx) {
- ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+ ctx->NewBuilder()
+     .Split(user_op::OpArg("dy", 0), 0)
+     .Split(user_op::OpArg("x", 0), 0)
+     .Split(user_op::OpArg("dx", 0), 0)
+     .Build();
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
@@ -263,7 +271,11 @@ namespace oneflow {
}

/*static*/ Maybe<void> UpsampleNearest2DGradOp::GetSbp(user_op::SbpContext* ctx) {
- ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+ ctx->NewBuilder()
+     .Split(user_op::OpArg("dy", 0), 0)
+     .Split(user_op::OpArg("x", 0), 0)
+     .Split(user_op::OpArg("dx", 0), 0)
+     .Build();
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
@@ -285,7 +297,11 @@ namespace oneflow {
}

/*static*/ Maybe<void> UpsampleBilinear2DGradOp::GetSbp(user_op::SbpContext* ctx) {
- ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+ ctx->NewBuilder()
+     .Split(user_op::OpArg("dy", 0), 0)
+     .Split(user_op::OpArg("x", 0), 0)
+     .Split(user_op::OpArg("dx", 0), 0)
+     .Build();
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleBilinear2DGradOp::InferLogicalTensorDesc(
@@ -308,7 +324,11 @@ namespace oneflow {
}

/*static*/ Maybe<void> UpsampleBicubic2DGradOp::GetSbp(user_op::SbpContext* ctx) {
- ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+ ctx->NewBuilder()
+     .Split(user_op::OpArg("dy", 0), 0)
+     .Split(user_op::OpArg("x", 0), 0)
+     .Split(user_op::OpArg("dx", 0), 0)
+     .Build();
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleBicubic2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
@@ -330,7 +350,11 @@ namespace oneflow {
}

/*static*/ Maybe<void> UpsampleGradOp::GetSbp(user_op::SbpContext* ctx) {
- ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+ ctx->NewBuilder()
+     .Split(user_op::OpArg("dy", 0), 0)
+     .Split(user_op::OpArg("x", 0), 0)
+     .Split(user_op::OpArg("dx", 0), 0)
+     .Build();
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
@@ -351,7 +375,11 @@ namespace oneflow {
}

/*static*/ Maybe<void> UpsampleNearest3DGradOp::GetSbp(user_op::SbpContext* ctx) {
- ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+ ctx->NewBuilder()
+     .Split(user_op::OpArg("dy", 0), 0)
+     .Split(user_op::OpArg("x", 0), 0)
+     .Split(user_op::OpArg("dx", 0), 0)
+     .Build();
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
@@ -373,7 +401,11 @@ namespace oneflow {
}

/*static*/ Maybe<void> UpsampleTrilinear3DGradOp::GetSbp(user_op::SbpContext* ctx) {
- ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+ ctx->NewBuilder()
+     .Split(user_op::OpArg("dy", 0), 0)
+     .Split(user_op::OpArg("x", 0), 0)
+     .Split(user_op::OpArg("dx", 0), 0)
+     .Build();
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleTrilinear3DGradOp::InferLogicalTensorDesc(
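All eight `GetSbp` fixes make the same change: each grad op consumes both `dy` and `x`, so the split signature must cover `x` alongside `dy` and `dx`. A minimal sketch of the invariant (hypothetical shapes; not OneFlow's actual inference code):

```python
def local_shape(logical_shape, split_axis, world_size):
    """Per-rank shape of a tensor split evenly along `split_axis`."""
    assert logical_shape[split_axis] % world_size == 0, "balanced split assumed"
    shape = list(logical_shape)
    shape[split_axis] //= world_size
    return tuple(shape)

# Every argument of the grad op is split along axis 0 (the batch axis),
# matching Split(OpArg(...), 0) for "dy", "x", and "dx".
logical = {"dy": (8, 16, 64), "x": (8, 16, 32), "dx": (8, 16, 32)}
locals_ = {name: local_shape(s, 0, world_size=2) for name, s in logical.items()}

assert locals_ == {"dy": (4, 16, 64), "x": (4, 16, 32), "dx": (4, 16, 32)}
# Each rank holds the same batch slice of all three tensors, so the
# backward kernel can run purely on local data.
assert locals_["dx"][0] == locals_["x"][0] == locals_["dy"][0]
```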
113 changes: 113 additions & 0 deletions python/oneflow/test/modules/test_consistent_upsample.py
@@ -0,0 +1,113 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
from collections import OrderedDict

import numpy as np
from oneflow.test_utils.test_util import GenArgList
from oneflow.test_utils.automated_test_util import *

import oneflow as flow
import oneflow.unittest


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_nearest(test_case, placement, sbp):
x = random_tensor(ndim=3, dim0=8, dim1=16).to_global(placement, sbp)
print(x)
m = torch.nn.Upsample(scale_factor=random().to(int), mode="nearest",)
y = m(x)
return y


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_linear(test_case, placement, sbp):
x = random_tensor(ndim=3, dim0=8, dim1=16).to_global(placement, sbp)
m = torch.nn.Upsample(
scale_factor=random().to(int), mode="linear", align_corners=random_bool(),
)
y = m(x)
return y


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_bilinear(test_case, placement, sbp):
x = random_tensor(ndim=4, dim0=8, dim1=16).to_global(placement, sbp)
m = torch.nn.Upsample(
scale_factor=random().to(int), mode="bilinear", align_corners=random_bool(),
)
y = m(x)
return y


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_bicubic(test_case, placement, sbp):
clackhan (Contributor Author, Mar 23, 2022):

In bicubic mode, OneFlow's backward result does not match PyTorch's; it is unclear whether the two implementations differ.

Repro: set `auto_backward=True`, then run:

python test_consistent_upsample.py --verbose --failfast

Contributor:

We need to confirm whether the results really differ, taking PyTorch 1.11 as the reference. If they really do not match, update it to https://github.com/Oneflow-Inc/OneTeam/issues/1207#issuecomment-1073432125 and I will debug it.

clackhan (Contributor Author, Mar 24, 2022):

> We need to confirm whether the results really differ, taking PyTorch 1.11 as the reference. If they really do not match, update it to Oneflow-Inc/OneTeam#1207 (comment) and I will debug it.

After upgrading PyTorch to 1.11 (the original test ran under 1.10), the backward results still differ; this has been updated in Oneflow-Inc/OneTeam#1207 (comment).

BBuf (Contributor, Mar 29, 2022):

@clackhan This bug has been fixed in #7916.

clackhan (Contributor Author):

> @clackhan This bug has been fixed in https://github.com/Oneflow-Inc/oneflow/pull/7916.

After merging PR #7916, enabling the backward test aborts immediately; with backward disabled it runs without problems. Error output:

python test_consistent_upsample.py --verbose --failfast
test_global_upsample2d_bicubic (__main__.TestGlobalUpsample2d) ... Environment has been initialized, this env init will do nothing.
/home/hanbinbin/anaconda3/envs/oneflow/lib/python3.8/site-packages/torch/_tensor.py:1104: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:475.)
  return self._grad
free(): corrupted unsorted chunks
Aborted (core dumped)

Contributor:

Strange, let me take another look.

clackhan (Contributor Author):

> Strange, let me take another look.

OK.

x = random_tensor(ndim=4, dim0=8, dim1=16).to_global(placement, sbp)
m = torch.nn.Upsample(
scale_factor=random().to(int), mode="bicubic", align_corners=random_bool(),
)
y = m(x)
return y


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_trilinear(test_case, placement, sbp):
x = random_tensor(ndim=5, dim0=8, dim1=16).to_global(placement, sbp)
m = torch.nn.Upsample(
scale_factor=random().to(int), mode="trilinear", align_corners=random_bool(),
)
y = m(x)
return y


class TestGlobalUpsample2d(flow.unittest.TestCase):
@unittest.skip(
"The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200"
)
@globaltest
def test_global_upsample2d_nearest(test_case):
Comment on lines +78 to +82
clackhan (Contributor Author, Mar 24, 2022):

The PyTorch issue referenced here has been closed, but this test's result still does not match PyTorch's; it is unclear whether the problem is on the OneFlow side or the PyTorch side.

Repro: comment out the `@unittest.skip` decorator, then run:

python test_consistent_upsample.py --verbose --failfast

Contributor:

It may be related to the relatively old PyTorch version in our CI environment; for now, skip it here as the local test does.

Contributor:

The PyTorch compatibility plan tentatively agrees on PyTorch 1.11 as the compatibility standard, so presumably CI can later be upgraded to PyTorch 1.11 across the board? @BBuf @hjchen2

Contributor:

@caishenghang

Contributor:

Shenghang and I ran into many problems during the upgrade and have not yet had time to resolve them one by one.

Contributor:

I re-ran the test; the cause is that PyTorch's CPU and GPU results disagree when the scale factor is not an integer. I did report this to PyTorch, but they closed my issue without fixing it, so let's leave this alone for now.

for placement in all_placement():
for sbp in all_sbp(placement, max_dim=1):
_test_global_upsample2d_nearest(test_case, placement, sbp)

@globaltest
def test_global_upsample2d_linear(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=1):
_test_global_upsample2d_linear(test_case, placement, sbp)

@globaltest
def test_global_upsample2d_bilinear(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=1):
_test_global_upsample2d_bilinear(test_case, placement, sbp)

@globaltest
def test_global_upsample2d_bicubic(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=1):
_test_global_upsample2d_bicubic(test_case, placement, sbp)

@globaltest
def test_global_upsample2d_trilinear(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=1):
_test_global_upsample2d_trilinear(test_case, placement, sbp)


if __name__ == "__main__":
unittest.main()
14 changes: 13 additions & 1 deletion python/oneflow/test/modules/test_upsample.py
@@ -377,7 +377,7 @@ def test_upsample2d(test_case):
"The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200"
)
@autotest()
- def test_upsample2d(test_case):
+ def test_upsample2d_nearest(test_case):
device = random_device()
x = random_tensor().to(device)
m = torch.nn.Upsample(scale_factor=random().to(float), mode="nearest")
@@ -399,6 +399,18 @@ def test_upsample2d_bilinear(test_case):
y = m(x)
return y

@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@autotest(atol=1e-5)
def test_upsample2d_bicubic(test_case):
x = random_tensor(ndim=4, dim0=16, dim1=8).to("cuda")
m = torch.nn.Upsample(
scale_factor=random().to(float),
mode="bicubic",
align_corners=random_bool(),
)
y = m(x)
return y


if __name__ == "__main__":
unittest.main()