Skip to content

Commit

Permalink
Fix upsample sbp infer bug and add global test (#7884)
Browse files Browse the repository at this point in the history
* fix reduce_sum scalar check bug

* fix_unfold_tensor_sbp_and_add_global_test

* refine

* add_var_upsample_global_test

* revert unfold_tensor_op.cpp

* fix var bug about tail element handle

* fix reshape infer error for 0-dim size

* delete code that does not belong to this branch

* fix bicubic interpolate op cuda kernel bug, test success

* fix bug

* fix segmentation fault bug

* make of_format

Co-authored-by: BBuf <1182563586@qq.com>
Co-authored-by: liufengwei0103 <2472937968@qq.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
5 people authored Apr 11, 2022
1 parent c1ad80c commit d3d7f2c
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 17 deletions.
2 changes: 1 addition & 1 deletion oneflow/user/kernels/upsample_bicubic2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class UpsampleBicubic2dGradCPUKernel final : public user_op::OpKernel {
const int64_t out_width = dy_tensor->shape().At(3);

if (in_height == out_height && in_width == out_width) {
memcpy(in_ptr, out_ptr, sizeof(T) * nbatch * channels * in_height * in_width);
memcpy(in_ptr, out_ptr, sizeof(T) * channels * in_height * in_width);
} else {
const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
Expand Down
10 changes: 5 additions & 5 deletions oneflow/user/kernels/upsample_bicubic2d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ __global__ void UpsampleBicubic2dForward(const int64_t elem_cnt, const T* in_dpt
T* out = out_dptr;

const T real_x = GetAreaPixel(scale_width, output_x, align_corners, /*cubic=*/true);
int64_t input_x = std::floor(1.0 * real_x);
int64_t input_x = floor(1.0 * real_x);
const T t_x = real_x - input_x;

const T real_y = GetAreaPixel(scale_height, output_y, align_corners, /*cubic=*/true);
int64_t input_y = std::floor(1.0 * real_y);
int64_t input_y = floor(1.0 * real_y);
const T t_y = real_y - input_y;

for (int64_t c = 0; c < channels * nbatch; c++) {
Expand Down Expand Up @@ -92,11 +92,11 @@ __global__ void UpsampleBicubic2dBackward(const int64_t elem_cnt, const T* dy_dp
const T* out = dy_dptr;

T real_x = GetAreaPixel(scale_width, output_x, align_corners, true);
int64_t input_x = std::floor(1.0 * real_x);
int64_t input_x = floor(1.0 * real_x);
T t_x = real_x - input_x;

T real_y = GetAreaPixel(scale_height, output_y, align_corners, true);
int64_t input_y = std::floor(1.0 * real_y);
int64_t input_y = floor(1.0 * real_y);
T t_y = real_y - input_y;

T x_coeffs[4];
Expand All @@ -105,7 +105,7 @@ __global__ void UpsampleBicubic2dBackward(const int64_t elem_cnt, const T* dy_dp
get_cubic_upsample_coefficients<T>(x_coeffs, t_x);
get_cubic_upsample_coefficients<T>(y_coeffs, t_y);

for (int64_t c = 0; c < channels; c++) {
for (int64_t c = 0; c < channels * nbatch; c++) {
T out_value = out[output_y * out_width + output_x];

for (int64_t i = 0; i < 4; i++) {
Expand Down
5 changes: 3 additions & 2 deletions oneflow/user/kernels/upsample_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,13 @@ OF_DEVICE_FUNC T upsample_get_value_bounded(const T* data, const int64_t width,

// Evaluates the cubic polynomial ((A + 2) * x - (A + 3)) * x * x + 1.
// Every numeric constant is cast to T so the whole expression is computed in T;
// bare double literals would promote narrower T operands to double first.
template<typename T>
OF_DEVICE_FUNC T cubic_convolution1(const T x, const T A) {
  return ((A + static_cast<T>(2.0)) * x - (A + static_cast<T>(3.0))) * x * x + static_cast<T>(1.0);
}

// Evaluates the cubic polynomial ((A * x - 5 * A) * x + 8 * A) * x - 4 * A.
// Every numeric constant is cast to T so the whole expression is computed in T;
// bare double literals would promote narrower T operands to double first.
template<typename T>
OF_DEVICE_FUNC T cubic_convolution2(const T x, const T A) {
  return ((A * x - static_cast<T>(5.0) * A) * x + static_cast<T>(8.0) * A) * x
         - static_cast<T>(4.0) * A;
}

template<typename T>
Expand Down
48 changes: 40 additions & 8 deletions oneflow/user/ops/upsample_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,11 @@ namespace oneflow {
}

// Registers one SBP signature for the op: "dy", "x" and "dx" are each split
// along axis 0. Returns Maybe<void>::Ok().
/*static*/ Maybe<void> UpsampleLinear1DGradOp::GetSbp(user_op::SbpContext* ctx) {
  // Only the signature that names every argument is registered; a dy/dx-only
  // signature would leave the "x" input without a split.
  ctx->NewBuilder()
      .Split(user_op::OpArg("dy", 0), 0)
      .Split(user_op::OpArg("x", 0), 0)
      .Split(user_op::OpArg("dx", 0), 0)
      .Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleLinear1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Expand All @@ -241,7 +245,11 @@ namespace oneflow {
}

// Registers one SBP signature for the op: "dy", "x" and "dx" are each split
// along axis 0. Returns Maybe<void>::Ok().
/*static*/ Maybe<void> UpsampleNearest1DGradOp::GetSbp(user_op::SbpContext* ctx) {
  // Only the signature that names every argument is registered; a dy/dx-only
  // signature would leave the "x" input without a split.
  ctx->NewBuilder()
      .Split(user_op::OpArg("dy", 0), 0)
      .Split(user_op::OpArg("x", 0), 0)
      .Split(user_op::OpArg("dx", 0), 0)
      .Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Expand All @@ -263,7 +271,11 @@ namespace oneflow {
}

// Registers one SBP signature for the op: "dy", "x" and "dx" are each split
// along axis 0. Returns Maybe<void>::Ok().
/*static*/ Maybe<void> UpsampleNearest2DGradOp::GetSbp(user_op::SbpContext* ctx) {
  // Only the signature that names every argument is registered; a dy/dx-only
  // signature would leave the "x" input without a split.
  ctx->NewBuilder()
      .Split(user_op::OpArg("dy", 0), 0)
      .Split(user_op::OpArg("x", 0), 0)
      .Split(user_op::OpArg("dx", 0), 0)
      .Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Expand All @@ -285,7 +297,11 @@ namespace oneflow {
}

// Registers one SBP signature for the op: "dy", "x" and "dx" are each split
// along axis 0. Returns Maybe<void>::Ok().
/*static*/ Maybe<void> UpsampleBilinear2DGradOp::GetSbp(user_op::SbpContext* ctx) {
  // Only the signature that names every argument is registered; a dy/dx-only
  // signature would leave the "x" input without a split.
  ctx->NewBuilder()
      .Split(user_op::OpArg("dy", 0), 0)
      .Split(user_op::OpArg("x", 0), 0)
      .Split(user_op::OpArg("dx", 0), 0)
      .Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleBilinear2DGradOp::InferLogicalTensorDesc(
Expand All @@ -308,7 +324,11 @@ namespace oneflow {
}

// Registers one SBP signature for the op: "dy", "x" and "dx" are each split
// along axis 0. Returns Maybe<void>::Ok().
/*static*/ Maybe<void> UpsampleBicubic2DGradOp::GetSbp(user_op::SbpContext* ctx) {
  // Only the signature that names every argument is registered; a dy/dx-only
  // signature would leave the "x" input without a split.
  ctx->NewBuilder()
      .Split(user_op::OpArg("dy", 0), 0)
      .Split(user_op::OpArg("x", 0), 0)
      .Split(user_op::OpArg("dx", 0), 0)
      .Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleBicubic2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Expand All @@ -330,7 +350,11 @@ namespace oneflow {
}

// Registers one SBP signature for the op: "dy", "x" and "dx" are each split
// along axis 0. Returns Maybe<void>::Ok().
/*static*/ Maybe<void> UpsampleGradOp::GetSbp(user_op::SbpContext* ctx) {
  // Only the signature that names every argument is registered; a dy/dx-only
  // signature would leave the "x" input without a split.
  ctx->NewBuilder()
      .Split(user_op::OpArg("dy", 0), 0)
      .Split(user_op::OpArg("x", 0), 0)
      .Split(user_op::OpArg("dx", 0), 0)
      .Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Expand All @@ -351,7 +375,11 @@ namespace oneflow {
}

// Registers one SBP signature for the op: "dy", "x" and "dx" are each split
// along axis 0. Returns Maybe<void>::Ok().
/*static*/ Maybe<void> UpsampleNearest3DGradOp::GetSbp(user_op::SbpContext* ctx) {
  // Only the signature that names every argument is registered; a dy/dx-only
  // signature would leave the "x" input without a split.
  ctx->NewBuilder()
      .Split(user_op::OpArg("dy", 0), 0)
      .Split(user_op::OpArg("x", 0), 0)
      .Split(user_op::OpArg("dx", 0), 0)
      .Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Expand All @@ -373,7 +401,11 @@ namespace oneflow {
}

// Registers one SBP signature for the op: "dy", "x" and "dx" are each split
// along axis 0. Returns Maybe<void>::Ok().
/*static*/ Maybe<void> UpsampleTrilinear3DGradOp::GetSbp(user_op::SbpContext* ctx) {
  // Only the signature that names every argument is registered; a dy/dx-only
  // signature would leave the "x" input without a split.
  ctx->NewBuilder()
      .Split(user_op::OpArg("dy", 0), 0)
      .Split(user_op::OpArg("x", 0), 0)
      .Split(user_op::OpArg("dx", 0), 0)
      .Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> UpsampleTrilinear3DGradOp::InferLogicalTensorDesc(
Expand Down
113 changes: 113 additions & 0 deletions python/oneflow/test/modules/test_consistent_upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
from collections import OrderedDict

import numpy as np
from oneflow.test_utils.test_util import GenArgList
from oneflow.test_utils.automated_test_util import *

import oneflow as flow
import oneflow.unittest


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_nearest(test_case, placement, sbp):
    """Apply nearest-mode ``Upsample`` to a random global tensor.

    A random 3-D tensor (``dim0=8``, ``dim1=16``) is converted to a global
    tensor with the given ``placement`` and ``sbp``, then upsampled with a
    random integer ``scale_factor``. The module output is returned for the
    ``autotest`` decorator to consume.
    """
    x = random_tensor(ndim=3, dim0=8, dim1=16).to_global(placement, sbp)
    m = torch.nn.Upsample(scale_factor=random().to(int), mode="nearest",)
    return m(x)


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_linear(test_case, placement, sbp):
    """Apply linear-mode ``Upsample`` to a random global tensor.

    The input is a random 3-D tensor (``dim0=8``, ``dim1=16``) made global
    with the given ``placement`` and ``sbp``; ``scale_factor`` is a random
    integer and ``align_corners`` a random boolean. Returns the module output.
    """
    global_input = random_tensor(ndim=3, dim0=8, dim1=16).to_global(placement, sbp)
    upsample = torch.nn.Upsample(
        scale_factor=random().to(int), mode="linear", align_corners=random_bool(),
    )
    return upsample(global_input)


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_bilinear(test_case, placement, sbp):
    """Apply bilinear-mode ``Upsample`` to a random global tensor.

    The input is a random 4-D tensor (``dim0=8``, ``dim1=16``) made global
    with the given ``placement`` and ``sbp``; ``scale_factor`` is a random
    integer and ``align_corners`` a random boolean. Returns the module output.
    """
    global_input = random_tensor(ndim=4, dim0=8, dim1=16).to_global(placement, sbp)
    upsample = torch.nn.Upsample(
        scale_factor=random().to(int), mode="bilinear", align_corners=random_bool(),
    )
    return upsample(global_input)


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_bicubic(test_case, placement, sbp):
    """Apply bicubic-mode ``Upsample`` to a random global tensor.

    The input is a random 4-D tensor (``dim0=8``, ``dim1=16``) made global
    with the given ``placement`` and ``sbp``; ``scale_factor`` is a random
    integer and ``align_corners`` a random boolean. Returns the module output.
    """
    global_input = random_tensor(ndim=4, dim0=8, dim1=16).to_global(placement, sbp)
    upsample = torch.nn.Upsample(
        scale_factor=random().to(int), mode="bicubic", align_corners=random_bool(),
    )
    return upsample(global_input)


@autotest(n=1, auto_backward=True, check_graph=False)
def _test_global_upsample2d_trilinear(test_case, placement, sbp):
    """Apply trilinear-mode ``Upsample`` to a random global tensor.

    The input is a random 5-D tensor (``dim0=8``, ``dim1=16``) made global
    with the given ``placement`` and ``sbp``; ``scale_factor`` is a random
    integer and ``align_corners`` a random boolean. Returns the module output.
    """
    global_input = random_tensor(ndim=5, dim0=8, dim1=16).to_global(placement, sbp)
    upsample = torch.nn.Upsample(
        scale_factor=random().to(int), mode="trilinear", align_corners=random_bool(),
    )
    return upsample(global_input)


class TestGlobalUpsample2d(flow.unittest.TestCase):
    # Every test below calls its module-level helper once for each
    # (placement, sbp) pair yielded by _placement_sbp_pairs().

    @staticmethod
    def _placement_sbp_pairs():
        # Lazily walk all_placement() and, for each placement, all_sbp() with
        # max_dim=1, in the same order as a nested for-loop.
        for device_placement in all_placement():
            for sbp_spec in all_sbp(device_placement, max_dim=1):
                yield device_placement, sbp_spec

    @unittest.skip(
        "The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200"
    )
    @globaltest
    def test_global_upsample2d_nearest(test_case):
        for device_placement, sbp_spec in test_case._placement_sbp_pairs():
            _test_global_upsample2d_nearest(test_case, device_placement, sbp_spec)

    @globaltest
    def test_global_upsample2d_linear(test_case):
        for device_placement, sbp_spec in test_case._placement_sbp_pairs():
            _test_global_upsample2d_linear(test_case, device_placement, sbp_spec)

    @globaltest
    def test_global_upsample2d_bilinear(test_case):
        for device_placement, sbp_spec in test_case._placement_sbp_pairs():
            _test_global_upsample2d_bilinear(test_case, device_placement, sbp_spec)

    @globaltest
    def test_global_upsample2d_bicubic(test_case):
        for device_placement, sbp_spec in test_case._placement_sbp_pairs():
            _test_global_upsample2d_bicubic(test_case, device_placement, sbp_spec)

    @globaltest
    def test_global_upsample2d_trilinear(test_case):
        for device_placement, sbp_spec in test_case._placement_sbp_pairs():
            _test_global_upsample2d_trilinear(test_case, device_placement, sbp_spec)


if __name__ == "__main__":
unittest.main()
14 changes: 13 additions & 1 deletion python/oneflow/test/modules/test_upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def test_upsample2d(test_case):
"The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200"
)
@autotest()
def test_upsample2d(test_case):
def test_upsample2d_nearest(test_case):
device = random_device()
x = random_tensor().to(device)
m = torch.nn.Upsample(scale_factor=random().to(float), mode="nearest")
Expand All @@ -399,6 +399,18 @@ def test_upsample2d_bilinear(test_case):
y = m(x)
return y

@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@autotest(atol=1e-5)
def test_upsample2d_bicubic(test_case):
    # Bicubic Upsample on a random 4-D tensor (dim0=16, dim1=8) moved to "cuda";
    # skipped when ONEFLOW_TEST_CPU_ONLY is set. scale_factor is a random float
    # and align_corners a random boolean.
    cuda_input = random_tensor(ndim=4, dim0=16, dim1=8).to("cuda")
    bicubic_upsample = torch.nn.Upsample(
        scale_factor=random().to(float),
        mode="bicubic",
        align_corners=random_bool(),
    )
    return bicubic_upsample(cuda_input)


if __name__ == "__main__":
unittest.main()

0 comments on commit d3d7f2c

Please sign in to comment.