diff --git a/docs/source/functional.rst b/docs/source/functional.rst index dd15c190d71..3bd7b53d322 100644 --- a/docs/source/functional.rst +++ b/docs/source/functional.rst @@ -31,3 +31,5 @@ Functional operations for neural networks .. autofunction:: one_hot .. autofunction:: dropout .. autofunction:: upsample +.. autofunction:: affine_grid +.. autofunction:: grid_sample diff --git a/oneflow/core/autograd/gradient_funcs/affine_grid.cpp b/oneflow/core/autograd/gradient_funcs/affine_grid.cpp new file mode 100644 index 00000000000..ae0b21c775d --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/affine_grid.cpp @@ -0,0 +1,69 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct AffineGridInterpState : public AutoGradCaptureState { + Shape size; + bool align_corners = false; + bool requires_grad = false; +}; + +class AffineGrid : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(AffineGridInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); // theta + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->size = JUST(composed_attrs.GetAttr("size")); + ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); + return Maybe::Ok(); + } + + Maybe Apply(const AffineGridInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + if (!ctx->requires_grad) { return Maybe::Ok(); } + + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + in_grads->at(0) = + JUST(functional::AffineGridGrad(out_grads.at(0), ctx->size, ctx->align_corners)); + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("affine_grid", AffineGrid); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/grid_sample.cpp b/oneflow/core/autograd/gradient_funcs/grid_sample.cpp new file mode 100644 index 00000000000..33bbff757d5 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/grid_sample.cpp @@ -0,0 +1,86 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct GridSampleInterpState : public AutoGradCaptureState { + std::string interpolation_mode = ""; + std::string padding_mode = ""; + bool align_corners = false; + size_t input_index = -1; + size_t grid_index = -1; + bool input_requires_grad = false; + bool grid_requires_grad = false; + bool requires_grad = false; +}; + +class GridSample : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(GridSampleInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 2); + ctx->input_requires_grad = inputs.at(0)->requires_grad(); + ctx->grid_requires_grad = inputs.at(1)->requires_grad(); + ctx->requires_grad = ctx->input_requires_grad || ctx->grid_requires_grad; + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); // input + ctx->grid_index = ctx->SaveTensorForBackward(inputs.at(1)); // grid + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->interpolation_mode = JUST(composed_attrs.GetAttr("interpolation_mode")); + ctx->padding_mode = JUST(composed_attrs.GetAttr("padding_mode")); + ctx->align_corners = JUST(composed_attrs.GetAttr("align_corners")); + return Maybe::Ok(); + } + + Maybe Apply(const GridSampleInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + if (!ctx->requires_grad) { return Maybe::Ok(); } + + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + + const auto& input = ctx->SavedTensors().at(ctx->input_index); + const auto& grid = ctx->SavedTensors().at(ctx->grid_index); + const auto& results = + JUST(functional::GridSampleGrad(out_grads.at(0), input, grid, ctx->interpolation_mode, + ctx->padding_mode, ctx->align_corners)); + in_grads->resize(2); + if (ctx->input_requires_grad) { in_grads->at(0) = results->at(0); } + if (ctx->grid_requires_grad) { in_grads->at(1) = results->at(1); } + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("grid_sample", GridSample); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 57381ca2cf3..00218f9c5db 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -519,6 +519,26 @@ "Tensor (Tensor dy, Tensor label, Tensor theta, Float m1, Float m2, Float m3, Int64 depth) => CombinedMarginLossGrad" bind_python: False +- name: "affine_grid" + signature: + "Tensor (Tensor theta, *, Shape size, Bool align_corners) => AffineGrid" + bind_python: True + +- name: "affine_grid_grad" + signature: + "Tensor (Tensor dgrid, *, Shape size, Bool align_corners) => 
AffineGridGrad" + bind_python: False + +- name: "grid_sample" + signature: + "Tensor (Tensor input, Tensor grid, *, String interpolation_mode, String padding_mode, Bool align_corners) => GridSample" + bind_python: True + +- name: "grid_sample_grad" + signature: + "TensorTuple (Tensor doutput, Tensor input, Tensor grid, *, String interpolation_mode, String padding_mode, Bool align_corners) => GridSampleGrad" + bind_python: False + - name: "where" signature: [ "Tensor (Tensor condition, Tensor x, Tensor y) => Where", diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index b34e74cf444..06e85da3d3d 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -459,6 +459,44 @@ class CombinedMarginLossFunctor { std::shared_ptr op_; }; +class AffineGridFunctor { + public: + AffineGridFunctor() { + op_ = CHECK_JUST(one::OpBuilder("affine_grid").Input("theta").Output("grid").Build()); + } + Maybe operator()(const std::shared_ptr& theta, const Shape& size, + const bool& align_corners) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("size", size)); + JUST(attrs.SetAttr("align_corners", align_corners)); + return OpInterpUtil::Dispatch(*op_, {theta}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class GridSampleFunctor { + public: + GridSampleFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("grid_sample").Input("input").Input("grid").Output("output").Build()); + } + Maybe operator()(const std::shared_ptr& input, + const std::shared_ptr& grid, + const std::string& interpolation_mode, const std::string& padding_mode, + const bool& align_corners) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("interpolation_mode", interpolation_mode)); + JUST(attrs.SetAttr("padding_mode", padding_mode)); + JUST(attrs.SetAttr("align_corners", align_corners)); + return OpInterpUtil::Dispatch(*op_, {input, grid}, attrs); + } + + private: + std::shared_ptr op_; +}; + class NormalizationFunctor { public: NormalizationFunctor() { @@ -998,6 +1036,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("SoftmaxCrossEntropyGrad"); m.add_functor("SmoothL1Loss"); m.add_functor("CombinedMarginLoss"); + m.add_functor("AffineGrid"); + m.add_functor("GridSample"); m.add_functor("Normalization"); m.add_functor("Pad"); m.add_functor("Dropout"); diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp index a77e9c9175b..75725b5752f 100644 --- a/oneflow/core/functional/impl/nn_grad_functor.cpp +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -276,6 +276,50 @@ class CombinedMarginLossGradFunctor { std::shared_ptr op_; }; +class AffineGridGradFunctor { + public: + AffineGridGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("affine_grid_grad").Input("dgrid").Output("dtheta").Build()); + } + Maybe operator()(const std::shared_ptr& dgrid, const Shape& size, + const bool& align_corners) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("size", size)); + JUST(attrs.SetAttr("align_corners", align_corners)); + return OpInterpUtil::Dispatch(*op_, {dgrid}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class GridSampleGradFunctor { + public: + GridSampleGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("grid_sample_grad") + .Input("doutput") + .Input("input") + .Input("grid") + .Output("dinput") + .Output("dgrid") + .Build()); + } + Maybe operator()(const std::shared_ptr& doutput, + const std::shared_ptr& input, + const std::shared_ptr& grid, + const std::string& 
interpolation_mode, + const std::string& padding_mode, const bool& align_corners) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("interpolation_mode", interpolation_mode)); + JUST(attrs.SetAttr("padding_mode", padding_mode)); + JUST(attrs.SetAttr("align_corners", align_corners)); + return OpInterpUtil::Dispatch(*op_, {doutput, input, grid}, attrs); + } + + private: + std::shared_ptr op_; +}; + class PadGradFunctor { public: PadGradFunctor() { @@ -412,6 +456,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("AdaptivePoolNdGrad"); m.add_functor("SmoothL1LossGrad"); m.add_functor("CombinedMarginLossGrad"); + m.add_functor("AffineGridGrad"); + m.add_functor("GridSampleGrad"); m.add_functor("PoolingNdGrad"); m.add_functor("PadGrad"); m.add_functor("AvgPoolingNdGrad"); diff --git a/oneflow/user/kernels/affine_grid_kernel.cpp b/oneflow/user/kernels/affine_grid_kernel.cpp new file mode 100644 index 00000000000..00954d40f5d --- /dev/null +++ b/oneflow/user/kernels/affine_grid_kernel.cpp @@ -0,0 +1,166 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/core/framework/config_def.h" +#include "affine_grid_kernel.h" + +namespace oneflow { + +template +class AffineGridKernel final : public user_op::OpKernel { + public: + AffineGridKernel() = default; + ~AffineGridKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* theta = ctx->Tensor4ArgNameAndIndex("theta", 0); + user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + const Shape& size = ctx->Attr("size"); + const bool& align_corners = ctx->Attr("align_corners"); + bool is_2d_grid = true; + if (size.NumAxes() == 5) { is_2d_grid = false; } + + int64_t theta_h = theta->shape().At(1); + int64_t theta_w = theta->shape().At(2); + + if (is_2d_grid) { + int64_t N = size.At(0); + int64_t H = size.At(2); + int64_t W = size.At(3); + // generate base grid + GenerateBaseGridImp::Generate2D(ctx, tmp_buffer->mut_dptr(), H, W, + align_corners); + // Compute each batch + for (int n = 0; n < N; n++) { + NewKernelUtil::OFGemm(ctx->device_ctx(), CblasNoTrans, CblasTrans, H * W, + theta_h, theta_w, 1.0, tmp_buffer->dptr(), + theta->dptr() + n * theta_h * theta_w, 0.0, + grid->mut_dptr() + n * theta_h * H * W); + } + } else { + int64_t N = size.At(0); + int64_t D = size.At(2); + int64_t H = size.At(3); + int64_t W = size.At(4); + // generate base grid + GenerateBaseGridImp::Generate3D(ctx, tmp_buffer->mut_dptr(), D, H, W, + align_corners); + // Compute each batch + for (int n = 0; n < N; n++) { + NewKernelUtil::OFGemm(ctx->device_ctx(), CblasNoTrans, CblasTrans, D * H * W, + theta_h, theta_w, 1.0, tmp_buffer->dptr(), + theta->dptr() + n * theta_h * theta_w, 0.0, + grid->mut_dptr() + n * theta_h * D * H * 
W); + } + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_AFFINE_GRID_KERNEL(device, dtype) \ + REGISTER_USER_KERNEL("affine_grid") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("theta", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ + const Shape& size = ctx->Attr("size"); \ + size_t tmp_buffer_size = size.Count(2) * (size.NumAxes() - 1) * sizeof(dtype); \ + return tmp_buffer_size; \ + }) + +REGISTER_AFFINE_GRID_KERNEL(DeviceType::kCPU, float); +REGISTER_AFFINE_GRID_KERNEL(DeviceType::kCPU, double); +#ifdef WITH_CUDA +REGISTER_AFFINE_GRID_KERNEL(DeviceType::kGPU, float); +REGISTER_AFFINE_GRID_KERNEL(DeviceType::kGPU, double); +#endif + +template +class AffineGridGradKernel final : public user_op::OpKernel { + public: + AffineGridGradKernel() = default; + ~AffineGridGradKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex("dgrid", 0); + user_op::Tensor* dtheta = ctx->Tensor4ArgNameAndIndex("dtheta", 0); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + const Shape& size = ctx->Attr("size"); + const bool& align_corners = ctx->Attr("align_corners"); + bool is_2d_grid = true; + if (size.NumAxes() == 5) { is_2d_grid = false; } + + int64_t dtheta_h = dtheta->shape().At(1); + int64_t dtheta_w = dtheta->shape().At(2); + + if (is_2d_grid) { + int64_t N = size.At(0); + int64_t H = size.At(2); + int64_t W = size.At(3); + // generate base grid + GenerateBaseGridImp::Generate2D(ctx, tmp_buffer->mut_dptr(), H, W, + align_corners); + // Compute each batch + for (int n = 0; n < N; n++) { + NewKernelUtil::OFGemm( + ctx->device_ctx(), CblasTrans, CblasNoTrans, dtheta_h, dtheta_w, H * W, 1.0, + dgrid->dptr() + n * dtheta_h * H * W, tmp_buffer->dptr(), 0.0, + dtheta->mut_dptr() + n * dtheta_h * dtheta_w); + } + } else { + int64_t N = size.At(0); + int64_t D = size.At(2); + int64_t H = size.At(3); + int64_t W = size.At(4); + GenerateBaseGridImp::Generate3D(ctx, tmp_buffer->mut_dptr(), D, H, W, + align_corners); + // Compute each batch + for (int n = 0; n < N; n++) { + NewKernelUtil::OFGemm( + ctx->device_ctx(), CblasTrans, CblasNoTrans, dtheta_h, dtheta_w, D * H * W, 1.0, + dgrid->dptr() + n * dtheta_h * D * H * W, tmp_buffer->dptr(), 0.0, + dtheta->mut_dptr() + n * dtheta_h * dtheta_w); + } + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_AFFINE_GRID_GRAD_KERNEL(device, dtype) \ + REGISTER_USER_KERNEL("affine_grid_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("dgrid", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ + const Shape& size = ctx->Attr("size"); \ + size_t tmp_buffer_size = size.Count(2) * (size.NumAxes() - 1) * sizeof(dtype); \ + return tmp_buffer_size; \ + }) + +REGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCPU, float); +REGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kCPU, double); +#ifdef WITH_CUDA +REGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kGPU, float); +REGISTER_AFFINE_GRID_GRAD_KERNEL(DeviceType::kGPU, double); +#endif + +} // namespace oneflow diff --git a/oneflow/user/kernels/affine_grid_kernel.cu b/oneflow/user/kernels/affine_grid_kernel.cu new file mode 100644 index 00000000000..b563fc2822c --- /dev/null +++ 
b/oneflow/user/kernels/affine_grid_kernel.cu @@ -0,0 +1,131 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/device/cuda_util.h" +#include "affine_grid_kernel.h" + +namespace oneflow { + +namespace { + +template +OF_DEVICE_FUNC data_type LinspaceGPU(int32_t index, int32_t num_steps) { + if (num_steps <= 1) { return static_cast(0.0); } + + if (align_corners) { + return static_cast(-1.0 + 2.0 / (num_steps - 1) * index); + } else { + return static_cast((-1.0 + 2.0 / (num_steps - 1) * index) * (num_steps - 1) + / num_steps); + } +} + +template +__global__ void Generate2DBaseGridGPUKernel(const int32_t nthreads, data_type* grid_ptr, int32_t H, + int32_t W) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const int32_t h = index / W; + const int32_t w = index % W; + const int32_t pixel_length = 3; + data_type* row_ptr = grid_ptr + h * W * pixel_length; + data_type* pixel_ptr = row_ptr + w * pixel_length; + data_type h_value = LinspaceGPU(h, H); + data_type w_value = LinspaceGPU(w, W); + + pixel_ptr[0] = w_value; + pixel_ptr[1] = h_value; + pixel_ptr[2] = static_cast(1.0); + } +} + +template +__global__ void Generate3DBaseGridGPUKernel(const int32_t nthreads, data_type* grid_ptr, int32_t D, + int32_t H, int32_t W) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const int32_t d = index / H; + const int32_t h = index % H; + const int32_t pixel_length = 4; + data_type* image_ptr = grid_ptr + d * H * W * pixel_length; + data_type* row_ptr = image_ptr + h * W * pixel_length; + data_type d_value = LinspaceGPU(d, D); + data_type h_value = LinspaceGPU(h, H); + + for (int32_t w = 0; w < W; ++w) { + data_type* pixel_ptr = row_ptr + w * pixel_length; + data_type w_value = LinspaceGPU(w, W); + pixel_ptr[0] = w_value; + pixel_ptr[1] = h_value; + pixel_ptr[2] = d_value; + pixel_ptr[3] = static_cast(1.0); + } + } +} + +} // namespace + +void GenerateBaseGridImp::Generate2D(user_op::KernelComputeContext* ctx, + float* grid_ptr, int64_t H, int64_t W, + bool align_corners) { + int count = H * W; + if (align_corners) { + RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel), ctx->device_ctx(), count, count, + grid_ptr, H, W); + } else { + RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel), ctx->device_ctx(), count, count, + grid_ptr, H, W); + } +} +void GenerateBaseGridImp::Generate2D(user_op::KernelComputeContext* ctx, + double* grid_ptr, int64_t H, int64_t W, + bool align_corners) { + int count = H * W; + if (align_corners) { + RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel), ctx->device_ctx(), count, count, + grid_ptr, H, W); + } else { + RUN_CUDA_KERNEL((Generate2DBaseGridGPUKernel), ctx->device_ctx(), count, count, + grid_ptr, H, W); + } +} + +void GenerateBaseGridImp::Generate3D(user_op::KernelComputeContext* ctx, + float* grid_ptr, int64_t D, int64_t H, + int64_t W, bool align_corners) { + int count = D * H; + if (align_corners) { + 
RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel), ctx->device_ctx(), count, count, + grid_ptr, D, H, W); + } else { + RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel), ctx->device_ctx(), count, count, + grid_ptr, D, H, W); + } +} + +void GenerateBaseGridImp::Generate3D(user_op::KernelComputeContext* ctx, + double* grid_ptr, int64_t D, int64_t H, + int64_t W, bool align_corners) { + int count = D * H; + if (align_corners) { + RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel), ctx->device_ctx(), count, count, + grid_ptr, D, H, W); + } else { + RUN_CUDA_KERNEL((Generate3DBaseGridGPUKernel), ctx->device_ctx(), count, count, + grid_ptr, D, H, W); + } +} + +} // namespace oneflow diff --git a/oneflow/user/kernels/affine_grid_kernel.h b/oneflow/user/kernels/affine_grid_kernel.h new file mode 100644 index 00000000000..50617ee264d --- /dev/null +++ b/oneflow/user/kernels/affine_grid_kernel.h @@ -0,0 +1,108 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_ +#define _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_ + +#include "oneflow/core/framework/op_kernel.h" +#include "oneflow/core/common/device_type.h" + +namespace oneflow { + +template +struct GenerateBaseGridImp {}; + +template<> +struct GenerateBaseGridImp { + template + static void Linspace(std::vector& grid, int64_t num_steps, bool align_corners) { + if (num_steps <= 1) { + for (auto& it : grid) { it = static_cast(0.0); } + return; + } + + if (align_corners) { + for (int i = 0; i < num_steps; i++) { + grid[i] = static_cast(-1.0 + 2.0 / (num_steps - 1) * i); + } + } else { + for (int i = 0; i < num_steps; i++) { + grid[i] = static_cast((-1.0 + 2.0 / (num_steps - 1) * i) * (num_steps - 1) + / num_steps); + } + } + } + + template + static void Generate2D(user_op::KernelComputeContext*, data_type* grid_ptr, int64_t H, int64_t W, + bool align_corners) { + std::vector w_step(W); + std::vector h_step(H); + Linspace(w_step, W, align_corners); + Linspace(h_step, H, align_corners); + + for (int h = 0; h < H; h++) { + data_type* row_ptr = grid_ptr + h * W * 3; + for (int w = 0; w < W; w++) { + data_type* pixel_ptr = row_ptr + w * 3; + pixel_ptr[0] = w_step[w]; + pixel_ptr[1] = h_step[h]; + pixel_ptr[2] = static_cast(1.0); + } + } + } + + template + static void Generate3D(user_op::KernelComputeContext*, data_type* grid_ptr, int64_t D, int64_t H, + int64_t W, bool align_corners) { + std::vector w_step(W); + std::vector h_step(H); + std::vector d_step(D); + Linspace(w_step, W, align_corners); + Linspace(h_step, H, align_corners); + Linspace(d_step, D, align_corners); + + for (int d = 0; d < D; d++) { + data_type* image_ptr = grid_ptr + d * H * W * 4; + for (int h = 0; h < H; h++) { + data_type* row_ptr = image_ptr + h * W * 4; + for (int w = 0; w < W; w++) { + data_type* pixel_ptr = row_ptr + w * 4; + pixel_ptr[0] = w_step[w]; + pixel_ptr[1] = h_step[h]; + pixel_ptr[2] = d_step[d]; + pixel_ptr[3] = static_cast(1.0); + } + } + } + } +}; + +template<> +struct 
GenerateBaseGridImp { + static void Generate2D(user_op::KernelComputeContext* ctx, float* grid_ptr, int64_t H, int64_t W, + bool align_corners); + static void Generate2D(user_op::KernelComputeContext* ctx, double* grid_ptr, int64_t H, int64_t W, + bool align_corners); + + static void Generate3D(user_op::KernelComputeContext* ctx, float* grid_ptr, int64_t D, int64_t H, + int64_t W, bool align_corners); + static void Generate3D(user_op::KernelComputeContext* ctx, double* grid_ptr, int64_t D, int64_t H, + int64_t W, bool align_corners); +}; + +} // namespace oneflow + +#endif // _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_ diff --git a/oneflow/user/kernels/grid_sample_kernel.cpp b/oneflow/user/kernels/grid_sample_kernel.cpp new file mode 100644 index 00000000000..b698bfa1988 --- /dev/null +++ b/oneflow/user/kernels/grid_sample_kernel.cpp @@ -0,0 +1,150 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/core/framework/config_def.h" +#include "grid_sample_kernel_util.h" + +namespace oneflow { + +template +class GridSampleKernel final : public user_op::OpKernel { + public: + GridSampleKernel() = default; + ~GridSampleKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); + user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); + const std::string interpolation_mode = ctx->Attr("interpolation_mode"); + const std::string padding_mode = ctx->Attr("padding_mode"); + GridSamplerInterpolation interpolation = StringToGridSamplerInterpolation(interpolation_mode); + GridSamplerPadding padding = StringToGridGridSamplerPadding(padding_mode); + const bool align_corners = ctx->Attr("align_corners"); + + const ShapeView& input_shape = input->shape(); + const ShapeView& grid_shape = grid->shape(); + const ShapeView& output_shape = output->shape(); + int64_t count = output_shape.elem_cnt() / input_shape.At(1); + + if (input_shape.NumAxes() == 4) { + if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) { + GridSampleKernelUtil::Forward4D( + ctx, input, grid, output, interpolation, padding, align_corners, input_shape, + grid_shape, output_shape, count); + } else { + GridSampleKernelUtil::Forward4D( + ctx, input, grid, output, interpolation, padding, align_corners, input_shape, + grid_shape, output_shape, count); + } + } else { + if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) { + GridSampleKernelUtil::Forward5D( + ctx, input, grid, output, interpolation, padding, align_corners, input_shape, + grid_shape, output_shape, count); + } else { + GridSampleKernelUtil::Forward5D( + ctx, input, grid, output, interpolation, padding, align_corners, input_shape, + grid_shape, output_shape, 
count); + } + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_GRID_SAMPLE_KERNEL(device, dtype) \ + REGISTER_USER_KERNEL("grid_sample") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("input", 0) == GetDataType::value)) + +REGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCPU, float); +REGISTER_GRID_SAMPLE_KERNEL(DeviceType::kCPU, double); +#ifdef WITH_CUDA +REGISTER_GRID_SAMPLE_KERNEL(DeviceType::kGPU, float); +REGISTER_GRID_SAMPLE_KERNEL(DeviceType::kGPU, double); +#endif + +template +class GridSampleGradKernel final : public user_op::OpKernel { + public: + GridSampleGradKernel() = default; + ~GridSampleGradKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* doutput = ctx->Tensor4ArgNameAndIndex("doutput", 0); + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); + user_op::Tensor* dinput = ctx->Tensor4ArgNameAndIndex("dinput", 0); + user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex("dgrid", 0); + const std::string interpolation_mode = ctx->Attr("interpolation_mode"); + const std::string padding_mode = ctx->Attr("padding_mode"); + GridSamplerInterpolation interpolation = StringToGridSamplerInterpolation(interpolation_mode); + GridSamplerPadding padding = StringToGridGridSamplerPadding(padding_mode); + const bool align_corners = ctx->Attr("align_corners"); + + const ShapeView& input_shape = input->shape(); + const ShapeView& grid_shape = grid->shape(); + const ShapeView& output_shape = doutput->shape(); + int64_t count = output_shape.elem_cnt() / input_shape.At(1); + + Memset(ctx->device_ctx(), dinput->mut_dptr(), 0, + input_shape.elem_cnt() * sizeof(data_type)); + + if (input_shape.NumAxes() == 4) { + if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) { + GridSampleKernelUtil::Backward4D( + ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners, + input_shape, grid_shape, output_shape, count); + } else { + GridSampleKernelUtil::Backward4D( + ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners, + input_shape, grid_shape, output_shape, count); + } + } else { + if (!CanUse32BitIndex({input_shape, grid_shape, output_shape})) { + GridSampleKernelUtil::Backward5D( + ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners, + input_shape, grid_shape, output_shape, count); + } else { + GridSampleKernelUtil::Backward5D( + ctx, doutput, input, grid, dinput, dgrid, interpolation, padding, align_corners, + input_shape, grid_shape, output_shape, count); + } + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_GRID_SAMPLE_GRAD_KERNEL(device, dtype) \ + REGISTER_USER_KERNEL("grid_sample_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("input", 0) == GetDataType::value)) + +REGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCPU, float); +REGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kCPU, double); +#ifdef WITH_CUDA +REGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kGPU, float); +REGISTER_GRID_SAMPLE_GRAD_KERNEL(DeviceType::kGPU, double); +#endif + +} // namespace oneflow diff --git a/oneflow/user/kernels/grid_sample_kernel_util.cpp b/oneflow/user/kernels/grid_sample_kernel_util.cpp new file mode 100644 index 00000000000..09ac81fbd3c --- 
/dev/null +++ b/oneflow/user/kernels/grid_sample_kernel_util.cpp @@ -0,0 +1,79 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "grid_sample_kernel_util.h" + +namespace oneflow { + +template +struct GridSampleKernelUtil final { + static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, + const user_op::Tensor* grid, user_op::Tensor* output, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { + GridSampler4DKernel( + count, input->dptr(), grid->dptr(), output->mut_dptr(), + input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), + output_shape.At(2), output_shape.At(3), interpolation, padding, align_corners); + } + + static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, + const user_op::Tensor* grid, user_op::Tensor* output, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { + GridSampler5DKernel( + count, input->dptr(), grid->dptr(), output->mut_dptr(), + input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), + input_shape.At(4), output_shape.At(2), output_shape.At(3), output_shape.At(4), + interpolation, padding, align_corners); + } + + static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, + const user_op::Tensor* input, const user_op::Tensor* grid, + user_op::Tensor* dinput, user_op::Tensor* dgrid, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, + int64_t count) { + GridSampler4DBackwardKernel( + count, doutput->dptr(), input->dptr(), grid->dptr(), + dinput->mut_dptr(), dgrid->mut_dptr(), input_shape.At(0), + input_shape.At(1), input_shape.At(2), input_shape.At(3), output_shape.At(2), + output_shape.At(3), interpolation, padding, align_corners, input_shape.elem_cnt()); + } + + static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, + const user_op::Tensor* input, const user_op::Tensor* grid, + user_op::Tensor* dinput, user_op::Tensor* dgrid, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, + int64_t count) { + GridSampler5DBackwardKernel( + count, doutput->dptr(), input->dptr(), grid->dptr(), + dinput->mut_dptr(), dgrid->mut_dptr(), input_shape.At(0), + input_shape.At(1), input_shape.At(2), input_shape.At(3), input_shape.At(4), + output_shape.At(2), output_shape.At(3), output_shape.At(4), interpolation, padding, + align_corners, input_shape.elem_cnt()); + } +}; + 
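+// Illustrative note (a sketch of the coordinate mapping, not part of the build): the CPU
+// specialization above delegates to the shared GridSampler4DKernel/GridSampler5DKernel loops
+// declared in grid_sample_kernel_util.h. As a worked example of the unnormalization those
+// loops rely on, with input width W = 4, GridSamplerUnnormalize maps a normalized grid
+// x-coordinate as follows:
+//   align_corners = true : ((x + 1) / 2) * (W - 1)  ->  x = -1 -> 0,    x = +1 -> 3
+//   align_corners = false: ((x + 1) * W - 1) / 2    ->  x = -1 -> -0.5, x = +1 -> 3.5
+// i.e. -1/+1 land on the centers of the corner pixels in the first case and on the image
+// edges in the second, matching the comments in the header. The bilinear branch then blends
+// the four neighbouring pixels of (ix, iy) with area weights for every (n, h, w) output.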
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL, (DeviceType::kCPU), + FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/user/kernels/grid_sample_kernel_util.cu b/oneflow/user/kernels/grid_sample_kernel_util.cu new file mode 100644 index 00000000000..b0473e910bd --- /dev/null +++ b/oneflow/user/kernels/grid_sample_kernel_util.cu @@ -0,0 +1,217 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/kernel_util.h" + +#include "grid_sample_kernel_util.h" + +namespace oneflow { + +class CudnnGridSampleDesc final { + public: + OF_DISALLOW_COPY_AND_MOVE(CudnnGridSampleDesc); + CudnnGridSampleDesc(DataType data_type, const ShapeView& shape) { + std::vector tensor_dim({shape.ptr(), shape.ptr() + shape.NumAxes()}); + OF_CUDNN_CHECK(cudnnCreateSpatialTransformerDescriptor(&val_)); + OF_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(val_, CUDNN_SAMPLER_BILINEAR, + GetCudnnDataType(data_type), + shape.NumAxes(), tensor_dim.data())); + } + + ~CudnnGridSampleDesc() { OF_CUDNN_CHECK(cudnnDestroySpatialTransformerDescriptor(val_)); } + + const cudnnSpatialTransformerDescriptor_t& Get() const { return val_; } + + private: + cudnnSpatialTransformerDescriptor_t val_; +}; + +template +struct CudnnGridSampleKernelUtil { + static bool CanRunWithCudnn(user_op::KernelComputeContext* ctx) { + if (ctx->Attr("interpolation_mode") != "bilinear" + || ctx->Attr("padding_mode") != "zeros" || !ctx->Attr("align_corners")) { + return false; + } + const ShapeView& input_shape = ctx->Tensor4ArgNameAndIndex("input", 0)->shape(); + if (input_shape.NumAxes() != 4 || input_shape.At(1) > 1024) { return false; } + + return true; + } + + static void ForwardCompute(user_op::KernelComputeContext* ctx) { + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); + user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); + const ShapeView& input_shape = input->shape(); + const ShapeView& output_shape = output->shape(); + const DataType dtype = input->data_type(); + + CudnnTensorDesc input_desc(dtype, input_shape, "channels_first"); + CudnnTensorDesc output_desc(dtype, output_shape, "channels_first"); + CudnnGridSampleDesc transfomer_desc(dtype, output_shape); + + OF_CUDNN_CHECK(cudnnSpatialTfSamplerForward( + ctx->device_ctx()->cudnn_handle(), transfomer_desc.Get(), CudnnSPOnePtr(), + input_desc.Get(), input->dptr(), grid->dptr(), CudnnSPZeroPtr(), output_desc.Get(), + output->mut_dptr())); + } + + static void BackwardCompute(user_op::KernelComputeContext* ctx) { + const user_op::Tensor* doutput = ctx->Tensor4ArgNameAndIndex("doutput", 0); + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + const user_op::Tensor* grid = ctx->Tensor4ArgNameAndIndex("grid", 0); + user_op::Tensor* dinput = ctx->Tensor4ArgNameAndIndex("dinput", 0); + 
user_op::Tensor* dgrid = ctx->Tensor4ArgNameAndIndex("dgrid", 0); + const ShapeView& input_shape = input->shape(); + const ShapeView& output_shape = doutput->shape(); + const ShapeView& dinput_shape = dinput->shape(); + const DataType dtype = input->data_type(); + + CudnnTensorDesc input_desc(dtype, input_shape, "channels_first"); + CudnnTensorDesc output_desc(dtype, output_shape, "channels_first"); + CudnnTensorDesc dinput_desc(dtype, dinput_shape, "channels_first"); + CudnnGridSampleDesc transfomer_desc(dtype, output_shape); + + OF_CUDNN_CHECK(cudnnSpatialTfSamplerBackward( + ctx->device_ctx()->cudnn_handle(), transfomer_desc.Get(), CudnnSPOnePtr(), + input_desc.Get(), input->dptr(), CudnnSPZeroPtr(), dinput_desc.Get(), dinput->mut_dptr(), + CudnnSPOnePtr(), output_desc.Get(), doutput->dptr(), grid->dptr(), CudnnSPZeroPtr(), + dgrid->mut_dptr())); + } +}; + +template +__launch_bounds__(256) __global__ + void CUDAGridSampler4DKernel(const index_type nthreads, const data_type* input_ptr, + const data_type* grid_ptr, data_type* output_ptr, index_type N, + index_type C, index_type inp_H, index_type inp_W, index_type out_H, + index_type out_W, + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, const bool align_corners) { + GridSampler4DKernel(nthreads, input_ptr, grid_ptr, output_ptr, N, C, inp_H, inp_W, out_H, out_W, + interpolation_mode, padding_mode, align_corners); +} + +template +__launch_bounds__(512) __global__ + void CUDAGridSampler5DKernel(const index_type nthreads, const data_type* input_ptr, + const data_type* grid_ptr, data_type* output_ptr, index_type N, + index_type C, index_type inp_D, index_type inp_H, index_type inp_W, + index_type out_D, index_type out_H, index_type out_W, + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, const bool align_corners) { + GridSampler5DKernel(nthreads, input_ptr, grid_ptr, output_ptr, N, C, inp_D, inp_H, inp_W, out_D, + out_H, out_W, interpolation_mode, padding_mode, align_corners); +} + +template +__launch_bounds__(256) __global__ void CUDAGridSampler4DBackwardKernel( + const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr, + const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N, + index_type C, index_type inp_H, index_type inp_W, index_type out_H, index_type out_W, + const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, + const bool align_corners, const index_type grad_input_memory_span) { + GridSampler4DBackwardKernel(nthreads, grad_output_ptr, input_ptr, grid_ptr, grad_input_ptr, + grad_grid_ptr, N, C, inp_H, inp_W, out_H, out_W, interpolation_mode, + padding_mode, align_corners, grad_input_memory_span); +} + +template +__launch_bounds__(256) __global__ void CUDAGridSampler5DBackwardKernel( + const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr, + const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N, + index_type C, index_type inp_D, index_type inp_H, index_type inp_W, index_type out_D, + index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, const bool align_corners, + const index_type grad_input_memory_span) { + GridSampler5DBackwardKernel(nthreads, grad_output_ptr, input_ptr, grid_ptr, grad_input_ptr, + grad_grid_ptr, N, C, inp_D, inp_H, inp_W, out_D, out_H, out_W, + interpolation_mode, padding_mode, 
align_corners, + grad_input_memory_span); +} + +template +struct GridSampleKernelUtil final { + static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, + const user_op::Tensor* grid, user_op::Tensor* output, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { + if (CudnnGridSampleKernelUtil::CanRunWithCudnn(ctx) + && CanUse32BitIndex({input_shape, grid_shape, output_shape})) { + return CudnnGridSampleKernelUtil::ForwardCompute(ctx); + } + + CUDAGridSampler4DKernel + <<device_ctx()->cuda_stream()>>>( + count, input->dptr(), grid->dptr(), output->mut_dptr(), + input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), + output_shape.At(2), output_shape.At(3), interpolation, padding, align_corners); + } + static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, + const user_op::Tensor* grid, user_op::Tensor* output, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count) { + CUDAGridSampler5DKernel + <<device_ctx()->cuda_stream()>>>( + count, input->dptr(), grid->dptr(), output->mut_dptr(), + input_shape.At(0), input_shape.At(1), input_shape.At(2), input_shape.At(3), + input_shape.At(4), output_shape.At(2), output_shape.At(3), output_shape.At(4), + interpolation, padding, align_corners); + } + + static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, + const user_op::Tensor* input, const user_op::Tensor* grid, + user_op::Tensor* dinput, user_op::Tensor* dgrid, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, + int64_t count) { + if (CudnnGridSampleKernelUtil::CanRunWithCudnn(ctx) + && CanUse32BitIndex({input_shape, grid_shape, output_shape})) { + return CudnnGridSampleKernelUtil::BackwardCompute(ctx); + } + + CUDAGridSampler4DBackwardKernel + <<device_ctx()->cuda_stream()>>>( + count, doutput->dptr(), input->dptr(), grid->dptr(), + dinput->mut_dptr(), dgrid->mut_dptr(), input_shape.At(0), + input_shape.At(1), input_shape.At(2), input_shape.At(3), output_shape.At(2), + output_shape.At(3), interpolation, padding, align_corners, input_shape.elem_cnt()); + } + static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, + const user_op::Tensor* input, const user_op::Tensor* grid, + user_op::Tensor* dinput, user_op::Tensor* dgrid, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, + int64_t count) { + CUDAGridSampler5DBackwardKernel + <<device_ctx()->cuda_stream()>>>( + count, doutput->dptr(), input->dptr(), grid->dptr(), + dinput->mut_dptr(), dgrid->mut_dptr(), input_shape.At(0), + input_shape.At(1), input_shape.At(2), input_shape.At(3), input_shape.At(4), + output_shape.At(2), output_shape.At(3), output_shape.At(4), interpolation, padding, + align_corners, input_shape.elem_cnt()); + } +}; + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL, (DeviceType::kGPU), + FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git 
a/oneflow/user/kernels/grid_sample_kernel_util.h b/oneflow/user/kernels/grid_sample_kernel_util.h new file mode 100644 index 00000000000..d4a92356cd3 --- /dev/null +++ b/oneflow/user/kernels/grid_sample_kernel_util.h @@ -0,0 +1,1083 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_ +#define ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_ + +#include "oneflow/core/common/shape_view.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/framework/op_kernel.h" +#include "oneflow/core/ndarray/xpu_util.h" +#include "oneflow/user/kernels/clip_by_value_kernel.h" +#ifdef WITH_CUDA +#include "oneflow/core/cuda/atomic.cuh" +#endif // WITH_CUDA + +namespace oneflow { + +enum class GridSamplerInterpolation { kBilinear = 0, kNearest, kBicubic }; + +enum class GridSamplerPadding { kZeros = 0, kBorder, kReflection }; + +static GridSamplerInterpolation StringToGridSamplerInterpolation(const std::string& mode) { + if (mode == "bilinear") { + return GridSamplerInterpolation::kBilinear; + } else if (mode == "nearest") { + return GridSamplerInterpolation::kNearest; + } + return GridSamplerInterpolation::kBicubic; +} +static GridSamplerPadding StringToGridGridSamplerPadding(const std::string& mode) { + if (mode == "zeros") { + return GridSamplerPadding::kZeros; + } else if (mode == "border") { + return GridSamplerPadding::kBorder; + } + return GridSamplerPadding::kReflection; +} +static bool CanUse32BitIndex(const std::initializer_list& shapes) { + for (const auto& shape : shapes) { + if (shape.elem_cnt() >= std::numeric_limits::max()) { return false; } + } + return true; +} + +inline int GridSampleGetBlocks(const int64_t number, const int64_t threads_per_block) { + // Round up division for positive number that cannot cause integer overflow + auto block_num = (number - 1) / threads_per_block + 1; + return static_cast(block_num); +} + +// This kernel implement is referenced from: +// https://github.com/pytorch/pytorch with git commit id: e7724bb +// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cu +// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cuh + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). 
+// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template +static OF_DEVICE_FUNC scalar_t GridSamplerUnnormalize(scalar_t coord, int size, + bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1.f) * size - 1) / 2; + } +} + +// GridSamplerUnnormalizeSetGrad works the same as GridSamplerUnnormalize +// except that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static OF_DEVICE_FUNC scalar_t GridSamplerUnnormalizeSetGrad(scalar_t coord, int size, + bool align_corners, + scalar_t* grad_in) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + *grad_in = static_cast(size) / 2; + return ((coord + 1.f) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template +static OF_DEVICE_FUNC scalar_t ClipCoordinates(scalar_t in, int clip_limit) { + return DeviceMin(static_cast(clip_limit - 1), DeviceMax(in, static_cast(0))); +} + +// ClipCoordinatesSetGrad works similarly to ClipCoordinates except that +// it also returns the `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static OF_DEVICE_FUNC scalar_t ClipCoordinatesSetGrad(scalar_t in, int clip_limit, + scalar_t* grad_in) { + // Note that it is important for the gradient calculation that borders + // are considered out of bounds. + if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + scalar_t max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template +static OF_DEVICE_FUNC scalar_t ReflectCoordinates(scalar_t in, int twice_low, int twice_high) { + if (twice_low == twice_high) { return static_cast(0); } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// ReflectCoordinatesSetGrad works similarly to ReflectCoordinates except +// that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. 
+template +static OF_DEVICE_FUNC scalar_t ReflectCoordinatesSetGrad(scalar_t in, int twice_low, int twice_high, + scalar_t* grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_ = 1; + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + // `fmod` returns same sign as `in`, which is positive after the `if` above. + scalar_t extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +#if defined(__CUDACC__) +template +static __device__ __forceinline__ scalar_t safe_downgrade_to_int_range(scalar_t x) { + // -100.0 does not have special meaning. This is just to make sure + // it's not WithinBounds2D or WithinBounds3D, and does not cause + // undefined behavior. See #35506. + if (x > INT_MAX - 1 || x < INT_MIN || !isfinite(static_cast(x))) + return static_cast(-100.0); + return x; +} +#endif + +template +static OF_DEVICE_FUNC scalar_t ComputeCoordinates(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners) { + if (padding_mode == GridSamplerPadding::kBorder) { + // clip coordinates to image borders + coord = ClipCoordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::kReflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = ReflectCoordinates(coord, 0, 2 * (size - 1)); + } else { + coord = ReflectCoordinates(coord, -1, 2 * size - 1); + } + // clip coordinates to image borders + coord = ClipCoordinates(coord, size); + } +#if defined(__CUDACC__) + coord = safe_downgrade_to_int_range(coord); +#endif + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +static OF_DEVICE_FUNC scalar_t GridSamplerComputeSourceIndex(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = GridSamplerUnnormalize(coord, size, align_corners); + coord = ComputeCoordinates(coord, size, padding_mode, align_corners); + return coord; +} + +// GridSamplerComputeSourceIndexSetGrad works similarly to +// GridSamplerComputeSourceIndex except that it also returns the +// `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. 
+template +static OF_DEVICE_FUNC scalar_t GridSamplerComputeSourceIndexSetGrad(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners, + scalar_t* grad_in) { + scalar_t grad_clip, grad_refl; + coord = GridSamplerUnnormalizeSetGrad(coord, size, align_corners, grad_in); + if (padding_mode == GridSamplerPadding::kBorder) { + // clip coordinates to image borders + coord = ClipCoordinatesSetGrad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == GridSamplerPadding::kReflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = ReflectCoordinatesSetGrad(coord, 0, 2 * (size - 1), &grad_refl); + } else { + coord = ReflectCoordinatesSetGrad(coord, -1, 2 * size - 1, &grad_refl); + } + // clip coordinates to image borders + coord = ClipCoordinatesSetGrad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + +#if defined(__CUDACC__) + coord = safe_downgrade_to_int_range(coord); +#endif + return coord; +} + +static OF_DEVICE_FUNC bool WithinBounds2D(int h, int w, int H, int W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +static OF_DEVICE_FUNC bool WithinBounds3D(int d, int h, int w, int D, int H, int W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +template +static OF_DEVICE_FUNC scalar_t GetValueBounded(const scalar_t* data, scalar_t x, scalar_t y, int W, + int H, int sW, int sH, + GridSamplerPadding padding_mode, + bool align_corners) { + x = ComputeCoordinates(x, W, padding_mode, align_corners); + y = ComputeCoordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + if (WithinBounds2D(iy, ix, H, W)) { return data[iy * sH + ix * sW]; } + return static_cast(0); +} + +template +static OF_DEVICE_FUNC void SafeAdd2D(scalar_t* data, int h, int w, int sH, int sW, int H, int W, + scalar_t delta, const index_t NC_offset, + const index_t memory_span) { + if (WithinBounds2D(h, w, H, W)) { +#if defined(__CUDACC__) + cuda::atomic::Add(data + NC_offset + h * sH + w * sW, delta); +#else + data[NC_offset + h * sH + w * sW] += delta; +#endif + } +} + +template +static OF_DEVICE_FUNC void SafeAdd3D(scalar_t* data, int d, int h, int w, int sD, int sH, int sW, + int D, int H, int W, scalar_t delta, const index_t NC_offset, + const index_t memory_span) { + if (WithinBounds3D(d, h, w, D, H, W)) { +#if defined(__CUDACC__) + cuda::atomic::Add(data + NC_offset + d * sD + h * sH + w * sW, delta); +#else + data[NC_offset + d * sD + h * sH + w * sW] += delta; +#endif + } +} + +template +static OF_DEVICE_FUNC void AddValueBounded(scalar_t* data, scalar_t x, scalar_t y, int W, int H, + int sW, int sH, scalar_t delta, + GridSamplerPadding padding_mode, bool align_corners, + const index_t NC_offset, const index_t memory_span) { + x = ComputeCoordinates(x, W, padding_mode, align_corners); + y = ComputeCoordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + SafeAdd2D(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span); +} + +// Calculate the differential of the cubic convolution, i.e. 
`d coeff / d x` +template +static OF_DEVICE_FUNC void GetCubicCoefficientsGrad(scalar_t coeffs[4], scalar_t t) { + // Must be the same as forward calculation in + // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + +// Based on +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +template +OF_DEVICE_FUNC static accscalar_t CubicConvolution1(accscalar_t x, accscalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +OF_DEVICE_FUNC static accscalar_t CubicConvolution2(accscalar_t x, accscalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +OF_DEVICE_FUNC static void GetCubicUpsamplingCoefficients(accscalar_t coeffs[4], accscalar_t t) { + accscalar_t A = -0.75; + + accscalar_t x1 = t; + coeffs[0] = CubicConvolution2(x1 + 1.0, A); + coeffs[1] = CubicConvolution1(x1, A); + + // opposite coefficients + accscalar_t x2 = 1.0 - t; + coeffs[2] = CubicConvolution1(x2, A); + coeffs[3] = CubicConvolution2(x2 + 1.0, A); +} + +template +OF_DEVICE_FUNC static accscalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, + accscalar_t t) { + accscalar_t coeffs[4]; + GetCubicUpsamplingCoefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +template +OF_DEVICE_FUNC void GridSampler4DKernel(const index_type nthreads, const data_type* input_ptr, + const data_type* grid_ptr, data_type* output_ptr, + index_type N, index_type C, index_type inp_H, + index_type inp_W, index_type out_H, index_type out_W, + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, + const bool align_corners) { + index_type inp_sN = C * inp_H * inp_W; + index_type inp_sC = inp_H * inp_W; + index_type inp_sH = inp_W; + index_type inp_sW = 1; + index_type grid_sN = out_H * out_W * 2; + index_type grid_sH = out_W * 2; + index_type grid_sW = 2; + index_type grid_sCoor = 1; + index_type out_sN = C * out_H * out_W; + index_type out_sC = out_H * out_W; + index_type out_sH = out_W; + index_type out_sW = 1; + + XPU_1D_KERNEL_LOOP(index, nthreads) { + const index_type w = index % out_W; + const index_type h = (index / out_W) % out_H; + const index_type n = index / (out_H * out_W); + const index_type grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + data_type x = grid_ptr[grid_offset]; + data_type y = grid_ptr[grid_offset + grid_sCoor]; + + data_type ix = GridSamplerComputeSourceIndex(x, inp_W, padding_mode, align_corners); + data_type iy = GridSamplerComputeSourceIndex(y, inp_H, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::kBilinear) { + // get NE, NW, SE, SW pixel values from (x, y) + index_type ix_nw = static_cast(::floor(ix)); + index_type iy_nw = static_cast(::floor(iy)); + index_type ix_ne = ix_nw + 1; + index_type iy_ne = iy_nw; + index_type ix_sw = ix_nw; + index_type iy_sw = iy_nw + 1; + index_type ix_se = ix_nw + 1; + index_type iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + data_type nw = (ix_se - ix) * (iy_se - iy); + data_type ne = 
(ix - ix_sw) * (iy_sw - iy); + data_type sw = (ix_ne - ix) * (iy - iy_ne); + data_type se = (ix - ix_nw) * (iy - iy_nw); + + // calculate bilinear weighted pixel value and set output pixel + auto inp_ptr_NC = input_ptr + n * inp_sN; + auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW; + for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + *out_ptr_NCHW = static_cast(0); + if (WithinBounds2D(iy_nw, ix_nw, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (WithinBounds2D(iy_ne, ix_ne, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (WithinBounds2D(iy_sw, ix_sw, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (WithinBounds2D(iy_se, ix_se, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + } + } + } else if (interpolation_mode == GridSamplerInterpolation::kNearest) { + index_type ix_nearest = static_cast(::round(ix)); + index_type iy_nearest = static_cast(::round(iy)); + + // assign nearest neighor pixel value to output pixel + auto inp_ptr_NC = input_ptr + n * inp_sN; + auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW; + for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + if (WithinBounds2D(iy_nearest, ix_nearest, inp_H, inp_W)) { + *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCHW = static_cast(0); + } + } + } else if (interpolation_mode == GridSamplerInterpolation::kBicubic) { + ix = GridSamplerUnnormalize(x, inp_W, align_corners); + iy = GridSamplerUnnormalize(y, inp_H, align_corners); + + data_type ix_nw = ::floor(ix); + data_type iy_nw = ::floor(iy); + + const data_type tx = ix - ix_nw; + const data_type ty = iy - iy_nw; + + auto inp_ptr_NC = input_ptr + n * inp_sN; + auto out_ptr_NCHW = output_ptr + n * out_sN + h * out_sH + w * out_sW; + for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + data_type coefficients[4]; +#ifdef __CUDA_ARCH__ +#pragma unroll 4 +#endif + for (index_type i = 0; i < 4; ++i) { + coefficients[i] = cubic_interp1d( + GetValueBounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, + inp_sH, padding_mode, align_corners), + GetValueBounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, + inp_sH, padding_mode, align_corners), + GetValueBounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, + inp_sH, padding_mode, align_corners), + GetValueBounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, + inp_sH, padding_mode, align_corners), + tx); + } + + *out_ptr_NCHW = + cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty); + } + } + } +} + +template +OF_DEVICE_FUNC void GridSampler5DKernel(const index_type nthreads, const data_type* input_ptr, + const data_type* grid_ptr, data_type* output_ptr, + index_type N, index_type C, index_type inp_D, + index_type inp_H, index_type inp_W, index_type out_D, + index_type out_H, index_type out_W, + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, + const bool align_corners) { + index_type inp_sN = C * inp_D * inp_H * inp_W; + index_type inp_sC = inp_D * inp_H * inp_W; + index_type inp_sD = inp_H * inp_W; + index_type inp_sH = inp_W; + index_type inp_sW = 1; + index_type grid_sN = out_D * out_H * out_W * 3; + index_type grid_sD = out_H * out_W * 3; + 
index_type grid_sH = out_W * 3; + index_type grid_sW = 3; + index_type grid_sCoor = 1; + index_type out_sN = C * out_D * out_H * out_W; + index_type out_sC = out_D * out_H * out_W; + index_type out_sD = out_H * out_W; + index_type out_sH = out_W; + index_type out_sW = 1; + + XPU_1D_KERNEL_LOOP(index, nthreads) { + const index_type w = index % out_W; + const index_type h = (index / out_W) % out_H; + const index_type d = (index / (out_H * out_W)) % out_D; + const index_type n = index / (out_D * out_H * out_W); + const index_type grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + data_type ix = grid_ptr[grid_offset]; + data_type iy = grid_ptr[grid_offset + grid_sCoor]; + data_type iz = grid_ptr[grid_offset + 2 * grid_sCoor]; + + ix = GridSamplerComputeSourceIndex(ix, inp_W, padding_mode, align_corners); + iy = GridSamplerComputeSourceIndex(iy, inp_H, padding_mode, align_corners); + iz = GridSamplerComputeSourceIndex(iz, inp_D, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::kBilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_type ix_tnw = static_cast(::floor(ix)); + index_type iy_tnw = static_cast(::floor(iy)); + index_type iz_tnw = static_cast(::floor(iz)); + + index_type ix_tne = ix_tnw + 1; + index_type iy_tne = iy_tnw; + index_type iz_tne = iz_tnw; + + index_type ix_tsw = ix_tnw; + index_type iy_tsw = iy_tnw + 1; + index_type iz_tsw = iz_tnw; + + index_type ix_tse = ix_tnw + 1; + index_type iy_tse = iy_tnw + 1; + index_type iz_tse = iz_tnw; + + index_type ix_bnw = ix_tnw; + index_type iy_bnw = iy_tnw; + index_type iz_bnw = iz_tnw + 1; + + index_type ix_bne = ix_tnw + 1; + index_type iy_bne = iy_tnw; + index_type iz_bne = iz_tnw + 1; + + index_type ix_bsw = ix_tnw; + index_type iy_bsw = iy_tnw + 1; + index_type iz_bsw = iz_tnw + 1; + + index_type ix_bse = ix_tnw + 1; + index_type iy_bse = iy_tnw + 1; + index_type iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + data_type tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + data_type tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + data_type tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + data_type tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + data_type bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + data_type bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + data_type bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + data_type bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input_ptr + n * inp_sN; + auto out_ptr_NCDHW = output_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse + *out_ptr_NCDHW = static_cast(0); + if (WithinBounds3D(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (WithinBounds3D(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (WithinBounds3D(iz_tsw, 
iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (WithinBounds3D(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (WithinBounds3D(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (WithinBounds3D(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (WithinBounds3D(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (WithinBounds3D(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + } + } else if (interpolation_mode == GridSamplerInterpolation::kNearest) { + index_type ix_nearest = static_cast(::round(ix)); + index_type iy_nearest = static_cast(::round(iy)); + index_type iz_nearest = static_cast(::round(iz)); + + // assign nearest neighor pixel value to output pixel + auto inp_ptr_NC = input_ptr + n * inp_sN; + auto out_ptr_NCDHW = output_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (index_type c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + if (WithinBounds3D(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW = + inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCDHW = static_cast(0); + } + } + } + } +} + +// Note [Passing pointer and offset to fastAtomicAdd] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// For its internal bounds checking, fastAtomicAdd needs to know where the destination address +// lies relative to the entire tensor, so we pass the base grad_input_ptr and full offset +// information, including batch * channel offset (NC_offset). 
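+
+// For the bilinear branch of the backward kernels below, the forward output is a weighted sum
+// of the corner pixels, e.g. out = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se in 2D, where
+// each weight is the area (volume, in the 3D case) of the sub-box opposite that corner. Hence
+// d out / d I_p is simply the corresponding weight (scattered into grad_input via SafeAdd2D /
+// SafeAdd3D), while the grid gradient comes from differentiating the weights with respect to
+// ix / iy (/ iz) and applying the chain-rule multipliers returned by
+// GridSamplerComputeSourceIndexSetGrad.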
+ +template +OF_DEVICE_FUNC void GridSampler4DBackwardKernel( + const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr, + const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N, + index_type C, index_type inp_H, index_type inp_W, index_type out_H, index_type out_W, + const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, + const bool align_corners, const index_type grad_input_memory_span) { + index_type inp_sN = C * inp_H * inp_W; + index_type inp_sC = inp_H * inp_W; + index_type inp_sH = inp_W; + index_type inp_sW = 1; + index_type grid_sN = out_H * out_W * 2; + index_type grid_sH = out_W * 2; + index_type grid_sW = 2; + index_type grid_sCoor = 1; + index_type gOut_sN = C * out_H * out_W; + index_type gOut_sC = out_H * out_W; + index_type gOut_sH = out_W; + index_type gOut_sW = 1; + index_type gInp_sN = inp_sN; + index_type gInp_sC = inp_sC; + index_type gInp_sH = inp_sH; + index_type gInp_sW = inp_sW; + index_type gGrid_sW = grid_sW; + + XPU_1D_KERNEL_LOOP(index, nthreads) { + const index_type w = index % out_W; + const index_type h = (index / out_W) % out_H; + const index_type n = index / (out_H * out_W); + const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + data_type x = grid_ptr[grid_offset]; + data_type y = grid_ptr[grid_offset + grid_sCoor]; + + // multipliers for gradients on ix and iy + data_type gix_mult, giy_mult; + data_type ix = + GridSamplerComputeSourceIndexSetGrad(x, inp_W, padding_mode, align_corners, &gix_mult); + data_type iy = + GridSamplerComputeSourceIndexSetGrad(y, inp_H, padding_mode, align_corners, &giy_mult); + + if (interpolation_mode == GridSamplerInterpolation::kBilinear) { + // get NE, NW, SE, SW pixel values from (x, y) + index_type ix_nw = static_cast(::floor(ix)); + index_type iy_nw = static_cast(::floor(iy)); + index_type ix_ne = ix_nw + 1; + index_type iy_ne = iy_nw; + index_type ix_sw = ix_nw; + index_type iy_sw = iy_nw + 1; + index_type ix_se = ix_nw + 1; + index_type iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + data_type nw = (ix_se - ix) * (iy_se - iy); + data_type ne = (ix - ix_sw) * (iy_sw - iy); + data_type sw = (ix_ne - ix) * (iy - iy_ne); + data_type se = (ix - ix_nw) * (iy - iy_nw); + + data_type gix = static_cast(0), giy = static_cast(0); + const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; + index_type NC_offset = n * gInp_sN; + const data_type* inp_ptr_NC = input_ptr + n * inp_sN; + for (index_type c = 0; c < C; + ++c, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) { + data_type gOut = *gOut_ptr_NCHW; + + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. 
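+        // Several output locations can sample the same input pixel, so this scatter must
+        // accumulate atomically on the GPU; SafeAdd2D falls back to a plain += on the CPU.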
+ SafeAdd2D(grad_input_ptr, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut, + NC_offset, grad_input_memory_span); + SafeAdd2D(grad_input_ptr, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut, + NC_offset, grad_input_memory_span); + SafeAdd2D(grad_input_ptr, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut, + NC_offset, grad_input_memory_span); + SafeAdd2D(grad_input_ptr, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut, + NC_offset, grad_input_memory_span); + + // calculate grad_grid + if (WithinBounds2D(iy_nw, ix_nw, inp_H, inp_W)) { + data_type nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]; + gix -= nw_val * (iy_se - iy) * gOut; + giy -= nw_val * (ix_se - ix) * gOut; + } + if (WithinBounds2D(iy_ne, ix_ne, inp_H, inp_W)) { + data_type ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]; + gix += ne_val * (iy_sw - iy) * gOut; + giy -= ne_val * (ix - ix_sw) * gOut; + } + if (WithinBounds2D(iy_sw, ix_sw, inp_H, inp_W)) { + data_type sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]; + gix -= sw_val * (iy - iy_ne) * gOut; + giy += sw_val * (ix_ne - ix) * gOut; + } + if (WithinBounds2D(iy_se, ix_se, inp_H, inp_W)) { + data_type se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]; + gix += se_val * (iy - iy_nw) * gOut; + giy += se_val * (ix - ix_nw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW + // 2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1] + data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; + } else if (interpolation_mode == GridSamplerInterpolation::kNearest) { + index_type ix_nearest = static_cast(::round(ix)); + index_type iy_nearest = static_cast(::round(iy)); + + // assign nearest neighor pixel value to output pixel + const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; + index_type NC_offset = n * gInp_sN; + for (index_type c = 0; c < C; ++c, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) { + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. + SafeAdd2D(grad_input_ptr, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, + *gOut_ptr_NCHW, NC_offset, grad_input_memory_span); + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW + // 2. 
directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1] + data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW; + gGrid_ptr_NHW[0] = static_cast(0); + gGrid_ptr_NHW[1] = static_cast(0); + } else if (interpolation_mode == GridSamplerInterpolation::kBicubic) { + ix = GridSamplerUnnormalizeSetGrad(x, inp_W, align_corners, &gix_mult); + iy = GridSamplerUnnormalizeSetGrad(y, inp_H, align_corners, &giy_mult); + + data_type ix_nw = ::floor(ix); + data_type iy_nw = ::floor(iy); + + const data_type tx = ix - ix_nw; + const data_type ty = iy - iy_nw; + + data_type x_coeffs[4]; + data_type y_coeffs[4]; + data_type x_coeffs_grad[4]; + data_type y_coeffs_grad[4]; + + GetCubicUpsamplingCoefficients(x_coeffs, tx); + GetCubicUpsamplingCoefficients(y_coeffs, ty); + GetCubicCoefficientsGrad(x_coeffs_grad, tx); + GetCubicCoefficientsGrad(y_coeffs_grad, ty); + + data_type gix = static_cast(0); + data_type giy = static_cast(0); + + const data_type* gOut_ptr_NCHW = grad_output_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; + index_type NC_offset = n * gInp_sN; + const data_type* inp_ptr_NC = input_ptr + n * inp_sN; + + for (index_type c = 0; c < C; + ++c, gOut_ptr_NCHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) { + data_type gOut = *gOut_ptr_NCHW; + +#ifdef __CUDA_ARCH__ +#pragma unroll 4 +#endif + for (index_type i = 0; i < 4; ++i) { +#ifdef __CUDA_ARCH__ +#pragma unroll 4 +#endif + for (index_type j = 0; j < 4; ++j) { + // set input gradient. See Note [Passing pointer and offset to fastAtomicAdd]. + AddValueBounded(grad_input_ptr, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H, + gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], + padding_mode, align_corners, NC_offset, + grad_input_memory_span); + + // set grid gradient + data_type val = + GetValueBounded(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H, + inp_sW, inp_sH, padding_mode, align_corners); + + gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; + giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; + } + } + } + + data_type* gGrid_ptr_NHW = grad_grid_ptr + index * gGrid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; + } + } +} + +template +OF_DEVICE_FUNC void GridSampler5DBackwardKernel( + const index_type nthreads, const data_type* grad_output_ptr, const data_type* input_ptr, + const data_type* grid_ptr, data_type* grad_input_ptr, data_type* grad_grid_ptr, index_type N, + index_type C, index_type inp_D, index_type inp_H, index_type inp_W, index_type out_D, + index_type out_H, index_type out_W, const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, const bool align_corners, + const index_type grad_input_memory_span) { + index_type inp_sN = C * inp_D * inp_H * inp_W; + index_type inp_sC = inp_D * inp_H * inp_W; + index_type inp_sD = inp_H * inp_W; + index_type inp_sH = inp_W; + index_type inp_sW = 1; + index_type grid_sN = out_D * out_H * out_W * 3; + index_type grid_sD = out_H * out_W * 3; + index_type grid_sH = out_W * 3; + index_type grid_sW = 3; + index_type grid_sCoor = 1; + index_type gOut_sN = C * out_D * out_H * out_W; + index_type gOut_sC = out_D * out_H * out_W; + index_type gOut_sD = out_H * out_W; + index_type gOut_sH = out_W; + index_type gOut_sW = 1; + index_type gInp_sN = inp_sN; + index_type gInp_sC = inp_sC; + index_type gInp_sD = inp_sD; + index_type gInp_sH = inp_sH; + index_type gInp_sW = inp_sW; + index_type gGrid_sW = grid_sW; + + XPU_1D_KERNEL_LOOP(index, nthreads) { + const index_type w = index % out_W; + const index_type h = (index / 
out_W) % out_H; + const index_type d = (index / (out_H * out_W)) % out_D; + const index_type n = index / (out_D * out_H * out_W); + const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + data_type ix = grid_ptr[grid_offset]; + data_type iy = grid_ptr[grid_offset + grid_sCoor]; + data_type iz = grid_ptr[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + data_type gix_mult, giy_mult, giz_mult; + ix = GridSamplerComputeSourceIndexSetGrad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = GridSamplerComputeSourceIndexSetGrad(iy, inp_H, padding_mode, align_corners, &giy_mult); + iz = GridSamplerComputeSourceIndexSetGrad(iz, inp_D, padding_mode, align_corners, &giz_mult); + + if (interpolation_mode == GridSamplerInterpolation::kBilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_type ix_tnw = static_cast(::floor(ix)); + index_type iy_tnw = static_cast(::floor(iy)); + index_type iz_tnw = static_cast(::floor(iz)); + + index_type ix_tne = ix_tnw + 1; + index_type iy_tne = iy_tnw; + index_type iz_tne = iz_tnw; + + index_type ix_tsw = ix_tnw; + index_type iy_tsw = iy_tnw + 1; + index_type iz_tsw = iz_tnw; + + index_type ix_tse = ix_tnw + 1; + index_type iy_tse = iy_tnw + 1; + index_type iz_tse = iz_tnw; + + index_type ix_bnw = ix_tnw; + index_type iy_bnw = iy_tnw; + index_type iz_bnw = iz_tnw + 1; + + index_type ix_bne = ix_tnw + 1; + index_type iy_bne = iy_tnw; + index_type iz_bne = iz_tnw + 1; + + index_type ix_bsw = ix_tnw; + index_type iy_bsw = iy_tnw + 1; + index_type iz_bsw = iz_tnw + 1; + + index_type ix_bse = ix_tnw + 1; + index_type iy_bse = iy_tnw + 1; + index_type iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + data_type tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + data_type tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + data_type tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + data_type tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + data_type bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + data_type bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + data_type bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + data_type bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + data_type gix = static_cast(0), giy = static_cast(0), + giz = static_cast(0); + const data_type* gOut_ptr_NCDHW = + grad_output_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + index_type NC_offset = n * gInp_sN; + const data_type* inp_ptr_NC = input_ptr + n * inp_sN; + // calculate bilinear weighted pixel value and set output pixel + for (index_type c = 0; c < C; + ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) { + data_type gOut = *gOut_ptr_NCDHW; + + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. 
+ SafeAdd3D(grad_input_ptr, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, + inp_W, tnw * gOut, NC_offset, grad_input_memory_span); + SafeAdd3D(grad_input_ptr, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, + inp_W, tne * gOut, NC_offset, grad_input_memory_span); + SafeAdd3D(grad_input_ptr, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, + inp_W, tsw * gOut, NC_offset, grad_input_memory_span); + SafeAdd3D(grad_input_ptr, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, + inp_W, tse * gOut, NC_offset, grad_input_memory_span); + SafeAdd3D(grad_input_ptr, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, + inp_W, bnw * gOut, NC_offset, grad_input_memory_span); + SafeAdd3D(grad_input_ptr, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, + inp_W, bne * gOut, NC_offset, grad_input_memory_span); + SafeAdd3D(grad_input_ptr, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, + inp_W, bsw * gOut, NC_offset, grad_input_memory_span); + SafeAdd3D(grad_input_ptr, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, + inp_W, bse * gOut, NC_offset, grad_input_memory_span); + + // calculate grad_grid + if (WithinBounds3D(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + data_type tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (WithinBounds3D(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + data_type tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (WithinBounds3D(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + data_type tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (WithinBounds3D(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + data_type tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; + } + if (WithinBounds3D(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + data_type bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (WithinBounds3D(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + data_type bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; + } + if (WithinBounds3D(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + data_type bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if 
(WithinBounds3D(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + data_type bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + data_type* gGrid_ptr_NDHW = grad_grid_ptr + index * gGrid_sW; + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + gGrid_ptr_NDHW[2] = giz_mult * giz; + } else if (interpolation_mode == GridSamplerInterpolation::kNearest) { + auto ix_nearest = static_cast(::round(ix)); + auto iy_nearest = static_cast(::round(iy)); + auto iz_nearest = static_cast(::round(iz)); + + // assign nearest neighor pixel value to output pixel + const data_type* gOut_ptr_NCDHW = + grad_output_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + index_type NC_offset = n * gInp_sN; + for (index_type c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC) { + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. + SafeAdd3D(grad_input_ptr, iz_nearest, iy_nearest, ix_nearest, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, *gOut_ptr_NCDHW, NC_offset, grad_input_memory_span); + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + data_type* gGrid_ptr_NDHW = grad_grid_ptr + index * gGrid_sW; + gGrid_ptr_NDHW[0] = static_cast(0); + gGrid_ptr_NDHW[1] = static_cast(0); + gGrid_ptr_NDHW[2] = static_cast(0); + } + } +} + +template +struct GridSampleKernelUtil final { + static void Forward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, + const user_op::Tensor* grid, user_op::Tensor* output, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count); + static void Forward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* input, + const user_op::Tensor* grid, user_op::Tensor* output, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count); + + static void Backward4D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, + const user_op::Tensor* input, const user_op::Tensor* grid, + user_op::Tensor* dinput, user_op::Tensor* dgrid, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count); + static void Backward5D(user_op::KernelComputeContext* ctx, const user_op::Tensor* doutput, + const user_op::Tensor* input, const user_op::Tensor* grid, + user_op::Tensor* dinput, user_op::Tensor* dgrid, + GridSamplerInterpolation interpolation, GridSamplerPadding padding, + const bool align_corners, const ShapeView& input_shape, + const ShapeView& grid_shape, const ShapeView& output_shape, int64_t count); +}; + +// macros for functors instantiate(used by grid_sample_kernel_util.cu, 
grid_sample_kernel_util.cpp)
+#define INSTANTIATE_GRID_SAMPLE_KERNEL_UTIL(device_type, dtype_pair, itype_pair) \
+  template struct GridSampleKernelUtil;
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_GRID_SAMPLE_KERNEL_H_
diff --git a/oneflow/user/ops/affine_grid_op.cpp b/oneflow/user/ops/affine_grid_op.cpp
new file mode 100644
index 00000000000..449cfe21560
--- /dev/null
+++ b/oneflow/user/ops/affine_grid_op.cpp
@@ -0,0 +1,149 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+namespace {
+
+Maybe CheckAttr(const user_op::UserOpDefWrapper& def,
+                const user_op::UserOpConfWrapper& conf) {
+  bool pass_checked = true;
+  std::stringstream err;
+  err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": ";
+
+  const auto& size = conf.attr("size");
+  if (size.NumAxes() != 4 && size.NumAxes() != 5) {
+    err << "dimension of size can't be:" << size.NumAxes();
+    pass_checked = false;
+  }
+
+  for (int i = 0; i < size.NumAxes(); i++) {
+    if (size.At(i) <= 0) { err << "element of size can't be:" << size.At(i); pass_checked = false; }
+  }
+
+  if (pass_checked) {
+    return Maybe::Ok();
+  } else {
+    return oneflow::Error::CheckFailedError() << err.str();
+  }
+}
+
+}  // namespace
+
+REGISTER_USER_OP("affine_grid")
+    .Input("theta")
+    .Output("grid")
+    .Attr("size")
+    .Attr("align_corners")
+    .SetCheckAttrFn(CheckAttr)
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe {
+      const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0);
+      user_op::TensorDesc* grid = ctx->OutputTensorDesc("grid", 0);
+      const Shape& size = ctx->Attr("size");
+      // Only support 2D or 3D affine grid with NCHW layout
+      // For 2D grid: theta = { N, 2, 3 },
+      //              size  = { N, C, H, W }
+      //              grid  = { N, H, W, 2 }
+      // For 3D grid: theta = { N, 3, 4 },
+      //              size  = { N, C, D, H, W }
+      //              grid  = { N, D, H, W, 3 }
+      bool is_2d_grid = true;
+      if (theta.shape().At(1) == 2) {
+        CHECK_EQ_OR_RETURN(theta.shape().At(2), 3) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)";
+        CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << "Dimension of size MUST be 4 for 2d affine grid";
+        CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0))
+            << "Theta and size MUST have same batch dimension";
+        is_2d_grid = true;
+      } else if (theta.shape().At(1) == 3) {
+        CHECK_EQ_OR_RETURN(theta.shape().At(2), 4) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)";
+        CHECK_EQ_OR_RETURN(size.NumAxes(), 5) << "Dimension of size MUST be 5 for 3d affine grid";
+        CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0))
+            << "Theta and size MUST have same batch dimension";
+        is_2d_grid = false;
+      } else {
+        CHECK_OR_RETURN(false) << "Theta MUST be 2D or 3D grid";
+      }
+      *grid->mut_is_dynamic() = theta.is_dynamic();
+      Shape& grid_shape = *grid->mut_shape();
+      if (is_2d_grid) {
+        grid_shape = {size.At(0), size.At(2), size.At(3), 2};
+      } else {
+        grid_shape = {size.At(0), size.At(2), size.At(3), size.At(4), 3};
+      }
+      return Maybe::Ok();
+    }) +
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("theta", 0), 0) + .Split(user_op::OpArg("grid", 0), 0) + .Build(); + return Maybe::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("grid", 0) = ctx->InputDType("theta", 0); + return Maybe::Ok(); + }); + +REGISTER_USER_OP("affine_grid_grad") + .Input("dgrid") + .Output("dtheta") + .Attr("size") + .Attr("align_corners") + .SetCheckAttrFn(CheckAttr) + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { + const Shape& size = ctx->Attr("size"); + + if (size.NumAxes() == 4) { + *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 2, 3}; + } else if (size.NumAxes() == 5) { + *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 3, 4}; + } else { + CHECK_OR_RETURN(false) << "size MUST be 4D or 5D"; + } + return Maybe::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("dgrid", 0), 0) + .Split(user_op::OpArg("dtheta", 0), 0) + .Build(); + return Maybe::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("dtheta", 0) = ctx->InputDType("dgrid", 0); + return Maybe::Ok(); + }); + +REGISTER_USER_OP_GRAD("affine_grid") + .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, + const user_op::AddOpFn& AddOp) -> Maybe { + if (op.NeedGenGradTensor4OpInput("theta", 0)) { + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + user_op::UserOpConfWrapper grad_op = + builder.Op("affine_grid_grad") + .Input("dgrid", op.GetGradTensorWithOpOutput("grid", 0)) + .Output("dtheta") + .Attr("size", op.attr("size")) + .Attr("align_corners", op.attr("align_corners")) + .Build(); + op.BindGradTensorWithOpInput(grad_op.output("dtheta", 0), "theta", 0); + AddOp(grad_op); + } + return Maybe::Ok(); + }); + +} // namespace oneflow diff --git a/oneflow/user/ops/grid_sample_op.cpp b/oneflow/user/ops/grid_sample_op.cpp new file mode 100644 index 00000000000..c415d858987 --- /dev/null +++ b/oneflow/user/ops/grid_sample_op.cpp @@ -0,0 +1,181 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+namespace {
+
+Maybe CheckAttr(const user_op::UserOpDefWrapper& def,
+                const user_op::UserOpConfWrapper& conf) {
+  bool pass_checked = true;
+  std::stringstream err;
+  err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": ";
+
+  const auto& interpolation_mode = conf.attr("interpolation_mode");
+  if (!(interpolation_mode == "bilinear" || interpolation_mode == "nearest"
+        || interpolation_mode == "bicubic")) {
+    err << " interpolation_mode:" << interpolation_mode;
+    pass_checked = false;
+  }
+
+  const auto& padding_mode = conf.attr("padding_mode");
+  if (!(padding_mode == "zeros" || padding_mode == "border" || padding_mode == "reflection")) {
+    err << " padding_mode:" << padding_mode;
+    pass_checked = false;
+  }
+
+  if (pass_checked) {
+    return Maybe::Ok();
+  } else {
+    return oneflow::Error::CheckFailedError() << err.str();
+  }
+}
+
+}  // namespace
+
+REGISTER_USER_OP("grid_sample")
+    .Input("input")
+    .Input("grid")
+    .Output("output")
+    .Attr("interpolation_mode")
+    .Attr("padding_mode")
+    .Attr("align_corners")
+    .SetCheckAttrFn(CheckAttr)
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe {
+      const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0);
+      const user_op::TensorDesc& grid = ctx->InputTensorDesc("grid", 0);
+      user_op::TensorDesc& output = *(ctx->OutputTensorDesc("output", 0));
+      // Only support 4D or 5D input with NCHW layout
+      // For 4D grid: input  = { N, C, H_in, W_in },
+      //              grid   = { N, H_out, W_out, 2 }
+      //              output = { N, C, H_out, W_out }
+      // For 5D grid: input  = { N, C, D_in, H_in, W_in },
+      //              grid   = { N, D_out, H_out, W_out, 3 }
+      //              output = { N, C, D_out, H_out, W_out }
+      const Shape& input_shape = input.shape();
+      const Shape& grid_shape = grid.shape();
+
+      bool is_4d_input = true;
+      if (input_shape.NumAxes() == 4) {
+        CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 4) << "Grid and input MUST have same dimension";
+        CHECK_EQ_OR_RETURN(grid_shape.At(3), 2) << "Grid shape MUST be (N, H_out, W_out, 2)";
+        is_4d_input = true;
+      } else if (input_shape.NumAxes() == 5) {
+        CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 5) << "Grid and input MUST have same dimension";
+        CHECK_EQ_OR_RETURN(grid_shape.At(4), 3) << "Grid shape MUST be (N, D_out, H_out, W_out, 3)";
+        if (ctx->Attr("interpolation_mode") == "bicubic") {
+          return oneflow::Error::CheckFailedError() << "Mode='bicubic' supports only 4-D input";
+        }
+        is_4d_input = false;
+      } else {
+        CHECK_OR_RETURN(false) << "MUST be 4D or 5D input";
+      }
+      *output.mut_is_dynamic() = grid.is_dynamic();
+      if (is_4d_input) {
+        *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1),
+                                 grid_shape.At(2)};
+      } else {
+        *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1),
+                                 grid_shape.At(2), grid_shape.At(3)};
+      }
+      return Maybe::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe {
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("input", 0), 0)
+          .Split(user_op::OpArg("grid", 0), 0)
+          .Split(user_op::OpArg("output", 0), 0)
+          .Build();
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("input", 0), 1)
+          .Broadcast(user_op::OpArg("grid", 0))
+          .Split(user_op::OpArg("output", 0), 1)
+          .Build();
+      return Maybe::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe {
+      *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0);
+      return Maybe::Ok();
+    });
+
+REGISTER_USER_OP("grid_sample_grad")
+    .Input("doutput")
+    .Input("input")
+    .Input("grid")
+    .Output("dinput") +
.Output("dgrid") + .Attr("interpolation_mode") + .Attr("padding_mode") + .Attr("align_corners") + .SetCheckAttrFn(CheckAttr) + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { + *(ctx->OutputTensorDesc("dinput", 0)->mut_shape()) = ctx->InputTensorDesc("input", 0).shape(); + *(ctx->OutputTensorDesc("dgrid", 0)->mut_shape()) = ctx->InputTensorDesc("grid", 0).shape(); + return Maybe::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("doutput", 0), 0) + .Split(user_op::OpArg("input", 0), 0) + .Split(user_op::OpArg("grid", 0), 0) + .Split(user_op::OpArg("dinput", 0), 0) + .Split(user_op::OpArg("dgrid", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("doutput", 0), 1) + .Split(user_op::OpArg("input", 0), 1) + .Broadcast(user_op::OpArg("grid", 0)) + .Split(user_op::OpArg("dinput", 0), 1) + .Broadcast(user_op::OpArg("dgrid", 0)) + .Build(); + return Maybe::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("dinput", 0) = ctx->InputDType("input", 0); + *ctx->OutputDType("dgrid", 0) = ctx->InputDType("grid", 0); + return Maybe::Ok(); + }); + +REGISTER_USER_OP_GRAD("grid_sample") + .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, + const user_op::AddOpFn& AddOp) -> Maybe { + if (op.NeedGenGradTensor4OpInput("input", 0) || op.NeedGenGradTensor4OpInput("grid", 0)) { + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + user_op::UserOpConfWrapper grad_op = + builder.Op("grid_sample_grad") + .Input("doutput", op.GetGradTensorWithOpOutput("output", 0)) + .Input("input", op.input("input", 0)) + .Input("grid", op.input("grid", 0)) + .Output("dinput") + .Output("dgrid") + .Attr("interpolation_mode", op.attr("interpolation_mode")) + .Attr("padding_mode", op.attr("padding_mode")) + .Attr("align_corners", op.attr("align_corners")) + .Build(); + + if (op.NeedGenGradTensor4OpInput("input", 0)) { + op.BindGradTensorWithOpInput(grad_op.output("dinput", 0), "input", 0); + } + if (op.NeedGenGradTensor4OpInput("grid", 0)) { + op.BindGradTensorWithOpInput(grad_op.output("dgrid", 0), "grid", 0); + } + AddOp(grad_op); + } + return Maybe::Ok(); + }); + +} // namespace oneflow diff --git a/python/oneflow/nn/functional/__init__.py b/python/oneflow/nn/functional/__init__.py index 5914c8a800d..ea5959b879b 100644 --- a/python/oneflow/nn/functional/__init__.py +++ b/python/oneflow/nn/functional/__init__.py @@ -15,6 +15,8 @@ """ from oneflow.nn.modules.interpolate import interpolate from oneflow.nn.modules.norm import l2_normalize +from oneflow.nn.modules.affine_grid import affine_grid +from oneflow.nn.modules.grid_sample import grid_sample from oneflow._C import conv1d from oneflow._C import conv2d from oneflow._C import conv3d diff --git a/python/oneflow/nn/modules/affine_grid.py b/python/oneflow/nn/modules/affine_grid.py new file mode 100644 index 00000000000..954762e2ec8 --- /dev/null +++ b/python/oneflow/nn/modules/affine_grid.py @@ -0,0 +1,74 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" +from typing import List + +import oneflow as flow + + +def affine_grid(theta, size: List[int], align_corners: bool = False): + """The interface is consistent with PyTorch. + The documentation is referenced from: + https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html?highlight=affine_grid#torch.nn.functional.affine_grid + + Generates a 2D or 3D flow field (sampling grid), given a batch of + affine matrices :attr:`theta`. + + .. note:: + This function is often used in conjunction with :func:`grid_sample` + to build `Spatial Transformer Networks`_ . + + Args: + theta (Tensor): input batch of affine matrices with shape + (:math:`N, 2, 3`) for 2D or + (:math:`N, 3, 4`) for 3D + size (flow.Size): the target output image size. + (:math:`N, C, H, W` for 2D or + :math:`N, C, D, H, W` for 3D) + Example: flow.Size((32, 3, 24, 24)) + align_corners (bool): if ``True``, consider ``-1`` and ``1`` + to refer to the centers of the corner pixels rather than the image corners. + Refer to :func:`grid_sample` for a more complete description. + A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample` + with the same setting for this option. + Default: ``False`` + + Returns: + output (Tensor): output Tensor of size (:math:`N, H, W, 2`) + + .. _`Spatial Transformer Networks`: + https://arxiv.org/abs/1506.02025 + + Examples:: + + >>> import oneflow as flow + >>> import numpy as np + >>> input = flow.tensor(np.arange(1., 7).reshape((1, 2, 3)), dtype=flow.float32) + >>> output = flow.nn.functional.affine_grid(input, flow.Size([1, 1, 2, 2]), align_corners=True) + >>> output + tensor([[[[ 0., -3.], + [ 2., 5.]], + + [[ 4., 7.], + [ 6., 15.]]]], dtype=oneflow.float32) + """ + y = flow._C.affine_grid(theta, size=size, align_corners=align_corners) + return y + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True) diff --git a/python/oneflow/nn/modules/grid_sample.py b/python/oneflow/nn/modules/grid_sample.py new file mode 100644 index 00000000000..0c4490c80b9 --- /dev/null +++ b/python/oneflow/nn/modules/grid_sample.py @@ -0,0 +1,145 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import oneflow as flow + + +def grid_sample( + input, + grid, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: bool = False, +): + """The interface is consistent with PyTorch. + The documentation is referenced from: + https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html?highlight=grid_sample#torch.nn.functional.grid_sample + + Given an :attr:`input` and a flow-field :attr:`grid`, computes the + ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. + + Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are + supported. 
+ + In the spatial (4-D) case, for :attr:`input` with shape + :math:`(N, C, H_{in}, W_{in})` and :attr:`grid` with shape + :math:`(N, H_{out}, W_{out}, 2)`, the output will have shape + :math:`(N, C, H_{out}, W_{out})`. + + For each output location ``output[n, :, h, w]``, the size-2 vector + ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``, + which are used to interpolate the output value ``output[n, :, h, w]``. + In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the + ``x``, ``y``, ``z`` pixel locations for interpolating + ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or + ``bilinear`` interpolation method to sample the input pixels. + + :attr:`grid` specifies the sampling pixel locations normalized by the + :attr:`input` spatial dimensions. Therefore, it should have most values in + the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the + left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` is the + right-bottom pixel of :attr:`input`. + + If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding + outputs are handled as defined by :attr:`padding_mode`. Options are + + * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations, + * ``padding_mode="border"``: use border values for out-of-bound grid locations, + * ``padding_mode="reflection"``: use values at locations reflected by + the border for out-of-bound grid locations. For location far away + from the border, it will keep being reflected until becoming in bound, + e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1`` + and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes + ``x'' = -0.5``. + + Note: + This function is often used in conjunction with :func:`affine_grid` + to build `Spatial Transformer Networks`_ . + + Note: + NaN values in :attr:`grid` would be interpreted as ``-1``. + + Args: + input (Tensor): input of shape :math:`(N, C, H_{in}, W_{in})` (4-D case) + or :math:`(N, C, D_{in}, H_{in}, W_{in})` (5-D case) + grid (Tensor): flow-field of shape :math:`(N, H_{out}, W_{out}, 2)` (4-D case) + or :math:`(N, D_{out}, H_{out}, W_{out}, 3)` (5-D case) + mode (str): interpolation mode to calculate output values + ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` + Note: ``mode='bicubic'`` supports only 4-D input. + When ``mode='bilinear'`` and the input is 5-D, the interpolation mode + used internally will actually be trilinear. However, when the input is 4-D, + the interpolation mode will legitimately be bilinear. + padding_mode (str): padding mode for outside grid values + ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'`` + align_corners (bool): Geometrically, we consider the pixels of the + input as squares rather than points. + If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring + to the center points of the input's corner pixels. If set to ``False``, they + are instead considered as referring to the corner points of the input's corner + pixels, making the sampling more resolution agnostic. + This option parallels the ``align_corners`` option in + :func:`interpolate`, and so whichever option is used here + should also be used there to resize the input image before grid sampling. + Default: ``False`` + + Returns: + output (Tensor): output Tensor + + .. _`Spatial Transformer Networks`: + https://arxiv.org/abs/1506.02025 + + .. 
note:: + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\\alpha=-0.75`. + The constant :math:`\\alpha` might be different from packages to packages. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + Clamp the results with :func: `flow.clamp` to ensure they are within the valid range. + .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 + .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 + + Examples:: + + >>> import oneflow as flow + >>> import numpy as np + >>> input = flow.tensor(np.arange(1., 11).reshape((1, 1, 2, 5)), dtype=flow.float32) + >>> np_grid = np.array( + ... [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]], + ... [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]] + ... ).reshape(1, 2, 5, 2) + >>> grid = flow.tensor(np_grid, dtype=flow.float32) + >>> output = flow.nn.functional.grid_sample(input, grid, mode='nearest', padding_mode='zeros', + ... align_corners=True) + >>> output + tensor([[[[0., 8., 5., 7., 9.], + [1., 8., 5., 8., 0.]]]], dtype=oneflow.float32) + """ + y = flow._C.grid_sample( + input, + grid, + interpolation_mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + return y + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True) diff --git a/python/oneflow/test/modules/test_affine_grid.py b/python/oneflow/test/modules/test_affine_grid.py new file mode 100644 index 00000000000..c658046be37 --- /dev/null +++ b/python/oneflow/test/modules/test_affine_grid.py @@ -0,0 +1,122 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest +from random import randint +from random import choice + +import numpy as np +from automated_test_util import * + +import oneflow as flow +import oneflow.unittest + + +class TestAffineGrid(flow.unittest.TestCase): + def test_affine_grid_2d(test_case): + input = flow.tensor(np.arange(1.0, 7).reshape((1, 2, 3)), dtype=flow.float32) + output = flow.nn.functional.affine_grid( + input, flow.Size([1, 1, 2, 2]), align_corners=True + ) + groundtruth = np.array([[[[0.0, -3.0], [2.0, 5.0]], [[4.0, 7.0], [6.0, 15.0]]]]) + test_case.assertTrue( + np.allclose(output.numpy(), groundtruth, rtol=1e-4, atol=1e-8) + ) + + output = flow.nn.functional.affine_grid( + input, flow.Size([1, 1, 2, 2]), align_corners=False + ) + groundtruth = np.array([[[[1.5, 1.5], [2.5, 5.5]], [[3.5, 6.5], [4.5, 10.5]]]]) + test_case.assertTrue( + np.allclose(output.numpy(), groundtruth, rtol=1e-4, atol=1e-8) + ) + + def test_affine_grid_3d(test_case): + input = flow.tensor(np.arange(1.0, 13).reshape((1, 3, 4)), dtype=flow.float32) + output = flow.nn.functional.affine_grid( + input, flow.Size([1, 1, 2, 2, 2]), align_corners=True + ) + groundtruth = np.array( + [ + [ + [ + [[-2.0, -10.0, -18.0], [0.0, 0.0, 0.0]], + [[2.0, 2.0, 2.0], [4.0, 12.0, 20.0]], + ], + [ + [[4.0, 4.0, 4.0], [6.0, 14.0, 22.0]], + [[8.0, 16.0, 24.0], [10.0, 26.0, 42.0]], + ], + ] + ] + ) + test_case.assertTrue( + np.allclose(output.numpy(), groundtruth, rtol=1e-4, atol=1e-8) + ) + + output = flow.nn.functional.affine_grid( + input, flow.Size([1, 1, 2, 2, 2]), align_corners=False + ) + groundtruth = np.array( + [ + [ + [ + [[1.0, -1.0, -3.0], [2.0, 4.0, 6.0]], + [[3.0, 5.0, 7.0], [4.0, 10.0, 16.0]], + ], + [ + [[4.0, 6.0, 8.0], [5.0, 11.0, 17.0]], + [[6.0, 12.0, 18.0], [7.0, 17.0, 27.0]], + ], + ] + ] + ) + test_case.assertTrue( + np.allclose(output.numpy(), groundtruth, rtol=1e-4, atol=1e-8) + ) + + @autotest() + def test_flow_affine_grid_2d_with_random_data(test_case): + N = randint(1, 8) + C = randint(1, 8) + H = randint(1, 8) + W = randint(1, 8) + device = random_device() + align_corners = choice([True, False]) + theta = random_pytorch_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device) + output = torch.nn.functional.affine_grid( + theta, (N, C, H, W), align_corners=align_corners + ).to(device) + return output + + @autotest(rtol=1e-03, atol=1e-03) + def test_flow_affine_grid_3d_with_random_data(test_case): + N = randint(1, 8) + C = randint(1, 8) + D = randint(1, 8) + H = randint(1, 8) + W = randint(1, 8) + device = random_device() + align_corners = choice([True, False]) + theta = random_pytorch_tensor(ndim=3, dim0=N, dim1=3, dim2=4).to(device) + output = torch.nn.functional.affine_grid( + theta, (N, C, D, H, W), align_corners=align_corners + ).to(device) + return output + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_grid_sample.py b/python/oneflow/test/modules/test_grid_sample.py new file mode 100644 index 00000000000..ca358591a21 --- /dev/null +++ b/python/oneflow/test/modules/test_grid_sample.py @@ -0,0 +1,141 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + import os + import unittest + from random import randint + from random import choice + + import numpy as np + from automated_test_util import * + + import oneflow as flow + import oneflow.unittest + + + class TestGridSample(flow.unittest.TestCase): + def test_grid_sample_4d(test_case): + input = flow.tensor( + np.arange(1.0, 11).reshape((1, 1, 2, 5)), dtype=flow.float32 + ) + np_grid = np.array( + [ + [[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]], + [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]], + ] + ).reshape(1, 2, 5, 2) + grid = flow.tensor(np_grid, dtype=flow.float32) + groundtruth = np.reshape( + np.array([[0.0, 8.0, 5.0, 7.0, 9.0], [1.0, 8.0, 5.0, 8.0, 0.0]]), + (1, 1, 2, 5), + ) + output = flow.nn.functional.grid_sample( + input, grid, mode="nearest", padding_mode="zeros", align_corners=True + ) + test_case.assertTrue( + np.allclose(output.numpy(), groundtruth, rtol=1e-4, atol=1e-8) + ) + + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + @autotest() + def test_flow_grid_sample_cudnn_with_random_data(test_case): + # cuDNN only supports 4D input with mode='bilinear', padding_mode='zeros' and align_corners=True + N = randint(1, 8) + C = randint(1, 8) + in_H = randint(1, 8) + in_W = randint(1, 8) + out_H = randint(1, 8) + out_W = randint(1, 8) + device = "cuda" + mode = "bilinear" + padding_mode = "zeros" + align_corners = True + theta = random_pytorch_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device) + grid = torch.nn.functional.affine_grid( + theta, (N, C, out_H, out_W), align_corners=align_corners + ).to(device) + input = random_pytorch_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to( + device + ) + output = torch.nn.functional.grid_sample( + input, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + return output + + @autotest() + def test_flow_grid_sample_4d_with_random_data(test_case): + N = randint(1, 8) + C = randint(1, 8) + in_H = randint(1, 8) + in_W = randint(1, 8) + out_H = randint(1, 8) + out_W = randint(1, 8) + device = random_device() + mode = choice(["bilinear", "nearest", "bicubic"]) + padding_mode = choice(["zeros", "border", "reflection"]) + align_corners = choice([True, False]) + theta = random_pytorch_tensor(ndim=3, dim0=N, dim1=2, dim2=3).to(device) + grid = torch.nn.functional.affine_grid( + theta, (N, C, out_H, out_W), align_corners=align_corners + ).to(device) + input = random_pytorch_tensor(ndim=4, dim0=N, dim1=C, dim2=in_H, dim3=in_W).to( + device + ) + output = torch.nn.functional.grid_sample( + input, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + return output + + @autotest(rtol=1e-03, atol=1e-03) + def test_flow_grid_sample_5d_with_random_data(test_case): + N = randint(1, 8) + C = randint(1, 8) + in_D = randint(1, 8) + in_H = randint(1, 8) + in_W = randint(1, 8) + out_D = randint(1, 8) + out_H = randint(1, 8) + out_W = randint(1, 8) + device = random_device() + mode = choice(["bilinear", "nearest"]) + padding_mode = choice(["zeros", "border", "reflection"]) + align_corners = choice([True, False]) + theta = 
random_pytorch_tensor(ndim=3, dim0=N, dim1=3, dim2=4).to(device) + grid = torch.nn.functional.affine_grid( + theta, (N, C, out_D, out_H, out_W), align_corners=align_corners + ).to(device) + input = random_pytorch_tensor( + ndim=5, dim0=N, dim1=C, dim2=in_D, dim3=in_H, dim4=in_W + ).to(device) + output = torch.nn.functional.grid_sample( + input, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + return output + + +if __name__ == "__main__": + unittest.main()
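For reference, the two new ops are meant to compose the same way as their PyTorch counterparts: ``affine_grid`` turns a batch of 2x3 (or 3x4) affine matrices into a normalized sampling grid, and ``grid_sample`` reads the input through that grid. The sketch below is not part of the patch; it assumes a wheel built from this branch that exposes ``flow.nn.functional.affine_grid`` / ``flow.nn.functional.grid_sample`` as documented above, and uses ``flow.clamp`` for the bicubic caveat from the docstring.

    import numpy as np
    import oneflow as flow

    # One sample, identity 2x3 affine matrix.
    identity = flow.tensor(np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]), dtype=flow.float32)
    input = flow.tensor(np.arange(1.0, 17.0).reshape(1, 1, 4, 4), dtype=flow.float32)

    # Sampling grid with the same spatial size as the input.
    grid = flow.nn.functional.affine_grid(identity, flow.Size([1, 1, 4, 4]), align_corners=True)

    # With an identity theta and align_corners=True the grid hits the pixel centers exactly,
    # so bilinear sampling should give back the input (up to floating point error).
    same = flow.nn.functional.grid_sample(
        input, grid, mode="bilinear", padding_mode="zeros", align_corners=True
    )

    # A mild zoom places sample points between pixels; bicubic interpolation can then
    # overshoot the input range, so clamp the result when a bounded range is required.
    zoom = flow.tensor(np.array([[[0.75, 0.0, 0.0], [0.0, 0.75, 0.0]]]), dtype=flow.float32)
    zoom_grid = flow.nn.functional.affine_grid(zoom, flow.Size([1, 1, 4, 4]), align_corners=True)
    resampled = flow.nn.functional.grid_sample(
        input, zoom_grid, mode="bicubic", padding_mode="border", align_corners=True
    )
    clamped = flow.clamp(resampled, 1.0, 16.0)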
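The hand-written 2D ground truth in ``TestAffineGrid.test_affine_grid_2d`` can be re-derived in a few lines of NumPy, which makes the asserted numbers easier to audit. This is only a sanity check of the expected values, not part of the test suite:

    import numpy as np

    theta = np.arange(1.0, 7.0).reshape(2, 3)  # [[1., 2., 3.], [4., 5., 6.]]

    # Base grid for H = W = 2 with align_corners=True: corner pixels map to -1 and +1,
    # with x varying fastest (W dimension), matching the (N, H, W, 2) grid layout.
    xs = np.array([-1.0, 1.0])
    ys = np.array([-1.0, 1.0])
    base = np.stack([np.tile(xs, 2), np.repeat(ys, 2), np.ones(4)], axis=1)  # rows (x, y, 1)

    grid = (base @ theta.T).reshape(1, 2, 2, 2)
    # grid == [[[[0., -3.], [2., 5.]], [[4., 7.], [6., 15.]]]], i.e. the `groundtruth`
    # asserted in the align_corners=True branch; the align_corners=False branch is the
    # same computation with xs = ys = [-0.5, 0.5].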