From a949914641db6564c58afdcfc06772322476553e Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Tue, 18 Aug 2020 03:51:27 +0000
Subject: [PATCH 1/4] Enhance affine grid operator:
 1. Add CUDA kernel
 2. Add align_corners option

test=develop
---
 paddle/fluid/operators/affine_grid_op.cc      | 17 +++++-
 paddle/fluid/operators/affine_grid_op.cu      | 58 +++++++++++++++++++
 paddle/fluid/operators/affine_grid_op.h       | 22 ++++---
 python/paddle/fluid/layers/nn.py              | 12 ++--
 .../tests/unittests/test_affine_grid_op.py    | 52 +++++++++++++++--
 5 files changed, 138 insertions(+), 23 deletions(-)
 create mode 100644 paddle/fluid/operators/affine_grid_op.cu

diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc
index f7cc513b234e6..d1a3695015abd 100644
--- a/paddle/fluid/operators/affine_grid_op.cc
+++ b/paddle/fluid/operators/affine_grid_op.cc
@@ -28,10 +28,15 @@ using Tensor = framework::Tensor;
 
 template <typename T>
 struct Linspace<paddle::platform::CPUDeviceContext, T> {
-  void operator()(T start, T end, int count, framework::Tensor* numbers,
+  void operator()(T start, T end, int count, bool align_corners,
+                  framework::Tensor* numbers,
                   const framework::ExecutionContext& ctx) {
     T* number_data = numbers->mutable_data<T>({count}, platform::CPUPlace());
     T slice = (end - start) / (T)(count - 1);
+    if (!align_corners) {
+      slice = (end - start) / (T)count;
+      start *= (T)(count - 1) / (T)count;
+    }
     for (int i = 0; i < count; ++i) {
       number_data[i] = start + (T)i * slice;
     }
@@ -130,6 +135,10 @@ class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
         "use_cudnn",
         "(bool, default false) Only used in cudnn kernel, need install cudnn")
         .SetDefault(true);
+    AddAttr<bool>("align_corners",
+                  "(bool, default true) Whether to align the corners of input "
+                  "and output.")
+        .SetDefault(true);
     AddAttr<std::vector<int>>(
         "output_shape",
         "The target output image shape with format [N, C, H, W].")
@@ -164,10 +173,12 @@ class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
            [-1.  -0.5  0.   0.5  1. ]
            [-1.  -0.5  0.   0.5  1. ]
            [-1.  -0.5  0.   0.5  1. ]]]
-  C[0] is the coordinates in height axis and C[1] is the coordinates in width axis.
+  C[0] is the coordinates in height axis and C[1] is the coordinates in
+  width axis.
 
   Step2:
-    Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get:
+    Transpose and reshape C to shape [H * W, 2] and append ones to last
+    dimension. Then we get:
       C_ = [[-1.  -1.   1. ]
             [-0.5 -1.   1. ]
             [ 0.  -1.   1. ]
diff --git a/paddle/fluid/operators/affine_grid_op.cu b/paddle/fluid/operators/affine_grid_op.cu
new file mode 100644
index 0000000000000..65723338cce74
--- /dev/null
+++ b/paddle/fluid/operators/affine_grid_op.cu
@@ -0,0 +1,58 @@
+/* Copyright (c) 2010 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/affine_grid_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
+  CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
+}
+
+template <typename T>
+struct Linspace<paddle::platform::CUDADeviceContext, T> {
+  void operator()(T start, T end, int count, bool align_corners,
+                  framework::Tensor* numbers,
+                  const framework::ExecutionContext& ctx) {
+    T* number_data = numbers->mutable_data<T>({count}, ctx.GetPlace());
+    T slice = (end - start) / (T)(count - 1);
+    if (!align_corners) {
+      slice = (end - start) / (T)count;
+      start *= (T)(count - 1) / (T)count;
+    }
+    auto stream = ctx.cuda_device_context().stream();
+    int block = 512;
+    int grid = (count + block - 1) / block;
+    LinspaceKernel<T><<<grid, block, 0, stream>>>(start, slice, count,
+                                                  number_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    affine_grid,
+    ops::AffineGridOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::AffineGridOpKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    affine_grid_grad,
+    ops::AffineGridGradOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::AffineGridGradOpKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/affine_grid_op.h b/paddle/fluid/operators/affine_grid_op.h
index 73df8a38b96c3..50c9ebcd9c8f5 100644
--- a/paddle/fluid/operators/affine_grid_op.h
+++ b/paddle/fluid/operators/affine_grid_op.h
@@ -37,12 +37,13 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
  */
 template <typename DeviceContext, typename T>
 struct Linspace {
-  void operator()(T start, T end, int count, framework::Tensor* numbers,
+  void operator()(T start, T end, int count, bool align_corners,
+                  framework::Tensor* numbers,
                   const framework::ExecutionContext& ctx);
 };
 
 template <typename DeviceContext, typename T>
-inline void GetIdxMap(int n, int h, int w, Tensor* grid,
+inline void GetIdxMap(int n, int h, int w, bool align_corners, Tensor* grid,
                       const framework::ExecutionContext& ctx) {
   auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
   grid->mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
@@ -50,16 +51,19 @@ inline void GetIdxMap(int n, int h, int w, Tensor* grid,
   // Get indexes of height with shape [height, width, 1]
   Tensor h_idx;
   Linspace<DeviceContext, T> linspace;
-  linspace((T)-1, (T)1, h, &h_idx, ctx);
+  linspace((T)-1, (T)1, h, align_corners, &h_idx, ctx);
   auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
   // Get indexes of width with shape [height, width, 1]
   Tensor w_idx;
-  linspace((T)-1, (T)1, w, &w_idx, ctx);
+  linspace((T)-1, (T)1, w, align_corners, &w_idx, ctx);
   auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
   // Get constant ones tensor with shape [height, width, 1]
   Tensor ones;
   ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
-  auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
+
+  math::SetConstant<DeviceContext, T>()(
+      ctx.template device_context<DeviceContext>(), &ones, static_cast<T>(1));
+  auto ones_t = EigenTensor<T, 3>::From(ones);
   // Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
   // ones
   Tensor w_idx_map;
@@ -74,11 +78,9 @@ inline void GetIdxMap(int n, int h, int w, Tensor* grid,
   Tensor w_h_one_idx_map;
   w_h_one_idx_map.mutable_data<T>({h, w, 3}, ctx.GetPlace());
   auto w_h_one_idx_map_t = EigenTensor<T, 3>::From(w_h_one_idx_map);
-
   w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w))
                                   .broadcast(Array2(h, 1))
                                   .reshape(Array3(h, w, 1));
-
   h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h))
                                   .broadcast(Array2(w, 1))
                                   .shuffle(Array2(1, 0))
@@ -97,6 +99,7 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
     auto* theta = ctx.Input<Tensor>("Theta");
     int n = theta->dims()[0];
     auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
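+    // With align_corners=true the -1/1 extremes of the normalized grid lie on
+    // the centers of the corner pixels; with align_corners=false they lie on
+    // the image border, so each sampled center shrinks by (size - 1) / size.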
+    auto align_corners = ctx.Attr<bool>("align_corners");
     int h = 0;
     int w = 0;
     if (size_attr.size() == 0) {
@@ -116,7 +119,7 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
         ctx.template device_context<DeviceContext>(), output,
         static_cast<T>(0));
     Tensor grid;
-    GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
+    GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
     // output = grid * theta.T
     // TODO(wanghaoshuang): Refine batched matrix multiply
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
@@ -140,6 +143,7 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
     auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
     int n = output_grad->dims()[0];
     auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
+    auto align_corners = ctx.Attr<bool>("align_corners");
     int h = 0;
     int w = 0;
     if (size_attr.size() == 0) {
@@ -158,7 +162,7 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
         ctx.template device_context<DeviceContext>(), theta_grad,
         static_cast<T>(0));
     Tensor grid;
-    GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
+    GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
     // output = grid * theta.T
     // TODO(wanghaoshuang): Refine batched matrix multiply
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 446510121e72a..747fca4d75a77 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -9081,12 +9081,9 @@ def _attr_offsets_check(offset_val):
     return out
 
 
-def affine_grid(theta, out_shape, name=None):
+def affine_grid(theta, out_shape, name=None, align_corners=True,
+                use_cudnn=True):
     """
-    :alias_main: paddle.nn.functional.affine_grid
-    :alias: paddle.nn.functional.affine_grid,paddle.nn.functional.vision.affine_grid
-    :old_api: paddle.fluid.layers.affine_grid
-
     It generates a grid of (x,y) coordinates using the parameters of
     the affine transformation that correspond to a set of points where
     the input feature map should be sampled to produce the transformed
@@ -9099,6 +9096,9 @@ def affine_grid(theta, out_shape, name=None):
                          ``out_shape`` can be a Tensor or a list or tuple. The data
                          type must be int32.
         name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+        align_corners(bool): Whether to align corners of target feature map and source feature map. Default: True.
+        use_cudnn(bool): When True, `align_corners` is ignored and the grid is always
+            computed in align-corners mode by cuDNN. Default: True.
 
     Returns:
         Variable: A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`.
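A minimal NumPy sketch of the sampling semantics documented above (illustration only, not part of the diff): with align_corners=True the sampled coordinates span [-1, 1] inclusive, while align_corners=False scales them by (count - 1) / count, which is exactly the Linspace change in affine_grid_op.cc.

    import numpy as np

    def normalized_coords(size, align_corners):
        # `size` points in [-1, 1]; the CPU and GPU Linspace functors
        # above produce the same values.
        coords = np.linspace(-1.0, 1.0, size)
        if not align_corners:
            # Pull the extremes in from the pixel corners to the pixel centers.
            coords = coords * (size - 1) / float(size)
        return coords

    print(normalized_coords(5, True))   # [-1.  -0.5  0.   0.5  1. ]
    print(normalized_coords(5, False))  # [-0.8 -0.4  0.   0.4  0.8]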
@@ -9140,7 +9140,7 @@ def affine_grid(theta, out_shape, name=None):
     out = helper.create_variable_for_type_inference(theta.dtype)
     ipts = {'Theta': theta}
-    attrs = {}
+    attrs = {"align_corners": align_corners, "use_cudnn": use_cudnn}
     if isinstance(out_shape, Variable):
         ipts['OutputShape'] = out_shape
         check_variable_and_dtype(out_shape, 'out_shape', ['int32'],
diff --git a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py
index 3668c4f4aa174..55612d71a17a7 100644
--- a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py
+++ b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py
@@ -17,14 +17,20 @@
 from op_test import OpTest
 
 
-def AffineGrid(theta, size):
+def AffineGrid(theta, size, align_corners):
     n = size[0]
     w = size[3]
     h = size[2]
+    h_factor = w_factor = 1
+    if not align_corners:
+        h_factor = (h - 1) / float(h)
+        w_factor = (w - 1) / float(w)
     h_idx = np.repeat(
-        np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
+        np.linspace(-1, 1, h)[np.newaxis, :], w,
+        axis=0).T[:, :, np.newaxis] * h_factor
     w_idx = np.repeat(
-        np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
+        np.linspace(-1, 1, w)[np.newaxis, :], h,
+        axis=0)[:, :, np.newaxis] * w_factor
     grid = np.concatenate(
         [w_idx, h_idx, np.ones([h, w, 1])], axis=2)  # h * w * 3
     grid = np.repeat(grid[np.newaxis, :], size[0], axis=0)  # n * h * w * 3
@@ -45,12 +51,17 @@ def setUp(self):
         theta = np.random.randint(1, 3, self.theta_shape).astype("float32")
         theta = np.ones(self.theta_shape).astype("float32")
         self.inputs = {'Theta': theta}
-        self.attrs = {"use_cudnn": True}
+        self.attrs = {
+            "use_cudnn": self.use_cudnn,
+            "align_corners": self.align_corners
+        }
         if self.dynamic_shape:
             self.inputs['OutputShape'] = self.output_shape
         else:
             self.attrs['output_shape'] = self.output_shape
-        self.outputs = {'Output': AffineGrid(theta, self.output_shape)}
+        self.outputs = {
+            'Output': AffineGrid(theta, self.output_shape, self.align_corners)
+        }
 
     def test_check_output(self):
         self.check_output()
@@ -62,6 +73,8 @@ def initTestCase(self):
         self.theta_shape = (17, 2, 3)
         self.output_shape = np.array([17, 2, 5, 7]).astype("int32")
         self.dynamic_shape = False
+        self.use_cudnn = False
+        self.align_corners = True
 
 
 class TestAffineGridOpCase1(TestAffineGridOp):
@@ -69,6 +82,35 @@ def initTestCase(self):
         self.theta_shape = (20, 2, 3)
         self.output_shape = np.array([20, 2, 5, 7]).astype("int32")
         self.dynamic_shape = True
+        self.use_cudnn = True
+        self.align_corners = True
+
+
+class TestAffineGridOpCase2(TestAffineGridOp):
+    def initTestCase(self):
+        self.theta_shape = (20, 2, 3)
+        self.output_shape = np.array([20, 2, 5, 7]).astype("int32")
+        self.dynamic_shape = True
+        self.use_cudnn = False
+        self.align_corners = True
+
+
+class TestAffineGridOpCase3(TestAffineGridOp):
+    def initTestCase(self):
+        self.theta_shape = (20, 2, 3)
+        self.output_shape = np.array([20, 2, 5, 7]).astype("int32")
+        self.dynamic_shape = True
+        self.use_cudnn = False
+        self.align_corners = False
+
+
+class TestAffineGridOpCase4(TestAffineGridOp):
+    def initTestCase(self):
+        self.theta_shape = (25, 2, 3)
+        self.output_shape = np.array([25, 2, 5, 6]).astype("int32")
+        self.dynamic_shape = False
+        self.use_cudnn = False
+        self.align_corners = False
 
 
 if __name__ == '__main__':

From d91f592f2d0a8a29c2a6a3710b19f4d0538dde9e Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 19 Aug 2020 04:30:55 +0000
Subject: [PATCH 2/4] Move new affine_grid api to functional

test=develop
---
 python/paddle/fluid/layers/nn.py              |   8 +-
 .../unittests/test_affine_grid_function.py    | 134 ++++++++++++++++++
 python/paddle/nn/functional/vision.py         |  90 +++++++++++-
 3 files changed, 225 insertions(+), 7 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_affine_grid_function.py

diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 747fca4d75a77..724866b06b35d 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -9081,8 +9081,7 @@ def _attr_offsets_check(offset_val):
     return out
 
 
-def affine_grid(theta, out_shape, name=None, align_corners=True,
-                use_cudnn=True):
+def affine_grid(theta, out_shape, name=None):
     """
     It generates a grid of (x,y) coordinates using the parameters of
     the affine transformation that correspond to a set of points where
@@ -9096,9 +9095,6 @@ def affine_grid(theta, out_shape, name=None):
                          ``out_shape`` can be a Tensor or a list or tuple. The data
                          type must be int32.
         name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
-        align_corners(bool): Whether to align corners of target feature map and source feature map. Default: True.
-        use_cudnn(bool): When True, `align_corners` is ignored and the grid is always
-            computed in align-corners mode by cuDNN. Default: True.
 
     Returns:
         Variable: A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`.
@@ -9140,7 +9136,7 @@ def affine_grid(theta, out_shape, name=None):
     out = helper.create_variable_for_type_inference(theta.dtype)
     ipts = {'Theta': theta}
-    attrs = {"align_corners": align_corners, "use_cudnn": use_cudnn}
+    attrs = {}
     if isinstance(out_shape, Variable):
         ipts['OutputShape'] = out_shape
         check_variable_and_dtype(out_shape, 'out_shape', ['int32'],
diff --git a/python/paddle/fluid/tests/unittests/test_affine_grid_function.py b/python/paddle/fluid/tests/unittests/test_affine_grid_function.py
new file mode 100644
index 0000000000000..ecef9a3456e94
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_affine_grid_function.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
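+
+# These cases cross-check three routes to the same operator: the legacy
+# fluid.layers.affine_grid layer (always align_corners=True), the new
+# static-graph paddle.nn.functional.affine_grid, and its dygraph form.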
+
+import numpy as np
+from paddle import fluid, nn
+import paddle.fluid.dygraph as dg
+import paddle.nn.functional as F
+import paddle.fluid.initializer as I
+import unittest
+
+
+class AffineGridTestCase(unittest.TestCase):
+    def __init__(self,
+                 methodName='runTest',
+                 theta_shape=(20, 2, 3),
+                 output_shape=[20, 2, 5, 7],
+                 align_corners=True,
+                 dtype="float32"):
+        super(AffineGridTestCase, self).__init__(methodName)
+
+        self.theta_shape = theta_shape
+        self.output_shape = output_shape
+        self.align_corners = align_corners
+        self.dtype = dtype
+
+    def setUp(self):
+        self.theta = np.random.randn(*(self.theta_shape)).astype(self.dtype)
+
+    def fluid_layer(self, place):
+        # align_corners = True
+        main = fluid.Program()
+        start = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, start):
+                theta_var = fluid.data(
+                    "input", self.theta_shape, dtype=self.dtype)
+                y_var = fluid.layers.affine_grid(theta_var, self.output_shape)
+        feed_dict = {"input": self.theta}
+        exe = fluid.Executor(place)
+        exe.run(start)
+        y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
+        return y_np
+
+    def functional(self, place):
+        main = fluid.Program()
+        start = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, start):
+                theta_var = fluid.data(
+                    "input", self.theta_shape, dtype=self.dtype)
+                y_var = F.affine_grid(
+                    theta_var,
+                    self.output_shape,
+                    align_corners=self.align_corners)
+        feed_dict = {"input": self.theta}
+        exe = fluid.Executor(place)
+        exe.run(start)
+        y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
+        return y_np
+
+    def paddle_dygraph_layer(self):
+        theta_var = dg.to_variable(self.theta)
+        y_var = F.affine_grid(
+            theta_var, self.output_shape, align_corners=self.align_corners)
+        y_np = y_var.numpy()
+        return y_np
+
+    def _test_equivalence(self, place):
+        result1 = self.fluid_layer(place)
+        result2 = self.functional(place)
+        with dg.guard(place):
+            result3 = self.paddle_dygraph_layer()
+        if self.align_corners:
+            np.testing.assert_array_almost_equal(result1, result2)
+        np.testing.assert_array_almost_equal(result2, result3)
+
+    def runTest(self):
+        place = fluid.CPUPlace()
+        self._test_equivalence(place)
+
+        if fluid.core.is_compiled_with_cuda():
+            place = fluid.CUDAPlace(0)
+            self._test_equivalence(place)
+
+
+class AffineGridErrorTestCase(AffineGridTestCase):
+    def runTest(self):
+        place = fluid.CPUPlace()
+        with dg.guard(place):
+            with self.assertRaises(ValueError):
+                self.paddle_dygraph_layer()
+
+
+def add_cases(suite):
+    suite.addTest(AffineGridTestCase(methodName='runTest'))
+    suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=True))
+
+    suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=False))
+
+    suite.addTest(
+        AffineGridTestCase(
+            methodName='runTest',
+            theta_shape=(20, 2, 3),
+            output_shape=[20, 1, 7, 7],
+            align_corners=True))
+
+
+def add_error_cases(suite):
+    suite.addTest(
+        AffineGridErrorTestCase(
+            methodName='runTest', output_shape="not_valid"))
+
+
+def load_tests(loader, standard_tests, pattern):
+    suite = unittest.TestSuite()
+    add_cases(suite)
+    add_error_cases(suite)
+    return suite
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py
index a2cc8fde5ad71..8389e2a10bdf6 100644
--- a/python/paddle/nn/functional/vision.py
+++ b/python/paddle/nn/functional/vision.py
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ...device import get_cudnn_version
+from ...fluid.framework import core, in_dygraph_mode, Variable
+from ...fluid.layer_helper import LayerHelper
+from ...fluid.data_feeder import check_variable_and_dtype
+
 # TODO: define specitial functions used in computer vision task
 from ...fluid.layers import affine_channel  #DEFINE_ALIAS
-from ...fluid.layers import affine_grid  #DEFINE_ALIAS
 from ...fluid.layers import anchor_generator  #DEFINE_ALIAS
 from ...fluid.layers import bipartite_match  #DEFINE_ALIAS
 from ...fluid.layers import box_clip  #DEFINE_ALIAS
@@ -89,3 +93,87 @@
     'yolo_box',
     'yolov3_loss'
 ]
+
+
+def affine_grid(theta, out_shape, align_corners=True, name=None):
+    """
+    It generates a grid of (x,y) coordinates using the parameters of
+    the affine transformation that correspond to a set of points where
+    the input feature map should be sampled to produce the transformed
+    output feature map.
+
+    Args:
+        theta (Variable) - A Tensor with shape [N, 2, 3]. It contains a batch of affine transform parameters.
+                           The data type can be float32 or float64.
+        out_shape (Variable | list | tuple): The shape of target output with format [batch_size, channel, height, width].
+                                             ``out_shape`` can be a Tensor or a list or tuple. The data
+                                             type must be int32.
+        align_corners(bool): Whether to align corners of target feature map and source feature map. Default: True.
+        name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`.
+
+    Raises:
+        ValueError: If the type of arguments is not supported.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn.functional as F
+            import numpy as np
+
+            paddle.disable_static()
+            place = paddle.CPUPlace()
+            theta_shape = [20, 2, 3]
+            theta = np.random.randn(*theta_shape).astype("float32")
+            theta_var = paddle.to_variable(theta)
+            y_var = F.affine_grid(
+                theta_var,
+                [20, 2, 5, 5],
+                align_corners=False)
+            y_np = y_var.numpy()
+            print(y_np)
+    """
+    helper = LayerHelper('affine_grid')
+
+    check_variable_and_dtype(theta, 'theta', ['float32', 'float64'],
+                             'affine_grid')
+
+    if get_cudnn_version() >= 6000 and align_corners:
+        use_cudnn = True
+    else:
+        use_cudnn = False
+
+    if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
+            isinstance(out_shape, Variable)):
+        raise ValueError("The out_shape should be a list, tuple or Variable.")
+
+    if in_dygraph_mode():
+        _out_shape = out_shape.numpy().tolist() if isinstance(
+            out_shape, Variable) else out_shape
+        return core.ops.affine_grid(theta, "output_shape", _out_shape,
+                                    "align_corners", align_corners,
+                                    "use_cudnn", use_cudnn)
+
+    if not isinstance(theta, Variable):
+        raise ValueError("The theta should be a Variable.")
+
+    out = helper.create_variable_for_type_inference(theta.dtype)
+    ipts = {'Theta': theta}
+    attrs = {"align_corners": align_corners, "use_cudnn": use_cudnn}
+    if isinstance(out_shape, Variable):
+        ipts['OutputShape'] = out_shape
+        check_variable_and_dtype(out_shape, 'out_shape', ['int32'],
+                                 'affine_grid')
+    else:
+        attrs['output_shape'] = out_shape
+
+    helper.append_op(
+        type='affine_grid',
+        inputs=ipts,
+        outputs={'Output': out},
+        attrs=None if len(attrs) == 0 else attrs)
+    return out

From 55a26f1646a8c7779239c28e0f2e6559f910bd8c Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Mon, 24 Aug 2020 03:52:25 +0000
Subject: [PATCH 3/4] Add CUDA kernel for affine_grid.

test=develop
---
 paddle/fluid/operators/affine_grid_op.cu | 171 +++++++++++++++++++++--
 python/paddle/nn/functional/vision.py    |  42 ++++--
 2 files changed, 188 insertions(+), 25 deletions(-)

diff --git a/paddle/fluid/operators/affine_grid_op.cu b/paddle/fluid/operators/affine_grid_op.cu
index 65723338cce74..eca8246533fea 100644
--- a/paddle/fluid/operators/affine_grid_op.cu
+++ b/paddle/fluid/operators/affine_grid_op.cu
@@ -1,4 +1,4 @@
-/* Copyright (c) 2010 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,7 +14,8 @@ limitations under the License.
 */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/affine_grid_op.h"
-
+#include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/gpu_info.h"
 namespace paddle {
 namespace operators {
 
@@ -44,15 +45,165 @@ struct Linspace<paddle::platform::CUDADeviceContext, T> {
   }
 };
 
+template <typename T>
+__global__ void affine_grid_kernel(const int count, int n, int out_h,
+                                   int out_w, T h_start, T w_start, T h_step,
+                                   T w_step, const T* theta,  // N, 2, 3
+                                   T* output) {
+  CUDA_KERNEL_LOOP(index, count) {
+    int w = index % out_w;
+    int h = (index / out_w) % out_h;
+    int n = index / (out_w * out_h);
+
+    T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
+    T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
+
+    int theta_offset = n * 6;  // 2 * 3
+    // affine from (h_coor, w_coor) to (x, y)
+    output[index * 2] = theta[theta_offset] * h_coor +
+                        theta[theta_offset + 1] * w_coor +
+                        theta[theta_offset + 2];
+    output[index * 2 + 1] = theta[theta_offset + 3] * h_coor +
+                            theta[theta_offset + 4] * w_coor +
+                            theta[theta_offset + 5];
+  }
+}
+
+template <typename T>
+__global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
+                                        int out_w, T h_start, T w_start,
+                                        T h_step, T w_step,
+                                        const T* out_grad,  // N, H, W, 2
+                                        T* theta_grad) {    // N, 2, 3
+  CUDA_KERNEL_LOOP(index, count) {
+    int w = index % out_w;
+    int h = (index / out_w) % out_h;
+    int n = index / (out_w * out_h);
+    T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
+    T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
+
+    int theta_offset = n * 6;  // 2 * 3
+    T out_grad_x = out_grad[index * 2];
+    atomicAdd(theta_grad + theta_offset, out_grad_x * h_coor);
+    atomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor);
+    atomicAdd(theta_grad + theta_offset + 2, out_grad_x);
+
+    T out_grad_y = out_grad[index * 2 + 1];
+    atomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor);
+    atomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
+    atomicAdd(theta_grad + theta_offset + 5, out_grad_y);
+  }
+}
+
+template <typename T>
+class AffineGridOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* theta = ctx.Input<Tensor>("Theta");
+    int n = theta->dims()[0];
+    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
+    auto align_corners = ctx.Attr<bool>("align_corners");
+    int h = 0;
+    int w = 0;
+    if (size_attr.size() == 0) {
+      auto* output_shape = ctx.Input<Tensor>("OutputShape");
+      Tensor h_sizes;
+      framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
+      const int* h_size_data = h_sizes.data<int>();
+      h = h_size_data[2];
+      w = h_size_data[3];
+    } else {
+      h = size_attr[2];
+      w = size_attr[3];
+    }
+    auto* output = ctx.Output<Tensor>("Output");
+    T* out_data = output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
+
+    T h_step;
+    T w_step;
+    T h_start = -1;
+    T w_start = -1;
+    if (align_corners) {
+      h_step = static_cast<T>(2) / static_cast<T>(h - 1);
+      w_step = static_cast<T>(2) / static_cast<T>(w - 1);
+    } else {
+      h_step = static_cast<T>(2) / static_cast<T>(h);
+      w_step = static_cast<T>(2) / static_cast<T>(w);
+
+      h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
+      w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
+    }
+
+    const int count = n * h * w;
+    int block = 512;
+    int grid = (count + block - 1) / block;
+    auto cu_stream = ctx.cuda_device_context().stream();
+    affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
+        count, n, h, w, h_start, w_start, h_step, w_step,
+        theta->data<T>(),  // N, 2, 3
+        out_data);
+  }
+};
+
+template <typename T>
+class AffineGridGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
+    int n = output_grad->dims()[0];
+    auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
+    auto align_corners = ctx.Attr<bool>("align_corners");
+    int h = 0;
+    int w = 0;
+    if (size_attr.size() == 0) {
+      auto* output_shape = ctx.Input<Tensor>("OutputShape");
+      Tensor h_sizes;
+      framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
+      const int* h_size_data = h_sizes.data<int>();
+      h = h_size_data[2];
+      w = h_size_data[3];
+    } else {
+      h = size_attr[2];
+      w = size_attr[3];
+    }
+    T* theta_grad_data = theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
+    math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
+        ctx.cuda_device_context(), theta_grad, static_cast<T>(0));
+
+    T h_step;
+    T w_step;
+    T h_start = -1;
+    T w_start = -1;
+    if (align_corners) {
+      h_step = static_cast<T>(2) / static_cast<T>(h - 1);
+      w_step = static_cast<T>(2) / static_cast<T>(w - 1);
+    } else {
+      h_step = static_cast<T>(2) / static_cast<T>(h);
+      w_step = static_cast<T>(2) / static_cast<T>(w);
+
+      h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
+      w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
+    }
+    const int count = n * h * w;
+    VLOG(3) << "count: " << count << "; h_step: " << h_step
+            << "; w_step: " << w_step << "; h_start: " << h_start
+            << "; w_start: " << w_start;
+    int block = 512;
+    int grid = (count + block - 1) / block;
+    auto cu_stream = ctx.cuda_device_context().stream();
+    affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(
+        count, n, h, w, h_start, w_start, h_step, w_step,
+        output_grad->data<T>(), theta_grad_data);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    affine_grid,
-    ops::AffineGridOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AffineGridOpKernel<paddle::platform::CUDADeviceContext, double>);
-REGISTER_OP_CUDA_KERNEL(
-    affine_grid_grad,
-    ops::AffineGridGradOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AffineGridGradOpKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(affine_grid, ops::AffineGridOpCUDAKernel<float>,
+                        ops::AffineGridOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(affine_grid_grad,
+                        ops::AffineGridGradOpCUDAKernel<float>,
+                        ops::AffineGridGradOpCUDAKernel<double>);
diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py
index 8389e2a10bdf6..d974b77b9b57d 100644
--- a/python/paddle/nn/functional/vision.py
+++ b/python/paddle/nn/functional/vision.py
@@ -103,16 +103,16 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):
         output feature map.
 
     Args:
-        theta (Variable) - A Tensor with shape [N, 2, 3]. It contains a batch of affine transform parameters.
+        theta (Tensor) - A tensor with shape [N, 2, 3]. It contains a batch of affine transform parameters.
                            The data type can be float32 or float64.
-        out_shape (Variable | list | tuple): The shape of target output with format [batch_size, channel, height, width].
+        out_shape (Tensor | list | tuple): The shape of target output with format [batch_size, channel, height, width].
                                              ``out_shape`` can be a Tensor or a list or tuple. The data
                                              type must be int32.
         align_corners(bool): Whether to align corners of target feature map and source feature map. Default: True.
         name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
-        Tensor: A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`.
+        Tensor, A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`.
 
     Raises:
         ValueError: If the type of arguments is not supported.
@@ -126,30 +126,42 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):
             import numpy as np
 
             paddle.disable_static()
-            place = paddle.CPUPlace()
-            theta_shape = [20, 2, 3]
-            theta = np.random.randn(*theta_shape).astype("float32")
-            theta_var = paddle.to_variable(theta)
-            y_var = F.affine_grid(
-                theta_var,
-                [20, 2, 5, 5],
+            # theta shape = [1, 2, 3]
+            theta = np.array([[[-0.7, -0.4, 0.3],
+                               [ 0.6,  0.5, 1.5]]]).astype("float32")
+            theta_t = paddle.to_tensor(theta)
+            y_t = F.affine_grid(
+                theta_t,
+                [1, 2, 3, 3],
                 align_corners=False)
-            y_np = y_var.numpy()
-            print(y_np)
+            print(y_t.numpy())
+
+            #[[[[ 1.0333333   0.76666665]
+            #   [ 0.76666665  1.0999999 ]
+            #   [ 0.5         1.4333333 ]]
+            #
+            #  [[ 0.5666667   1.1666666 ]
+            #   [ 0.3         1.5       ]
+            #   [ 0.03333333  1.8333334 ]]
+            #
+            #  [[ 0.10000002  1.5666667 ]
+            #   [-0.16666666  1.9000001 ]
+            #   [-0.43333334  2.2333333 ]]]]
     """
     helper = LayerHelper('affine_grid')
 
     check_variable_and_dtype(theta, 'theta', ['float32', 'float64'],
                              'affine_grid')
 
-    if get_cudnn_version() >= 6000 and align_corners:
+    cudnn_version = get_cudnn_version()
+    if cudnn_version is not None and cudnn_version >= 6000 and align_corners:
         use_cudnn = True
     else:
         use_cudnn = False
 
     if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
             isinstance(out_shape, Variable)):
-        raise ValueError("The out_shape should be a list, tuple or Variable.")
+        raise ValueError("The out_shape should be a list, tuple or Tensor.")
 
     if in_dygraph_mode():
         _out_shape = out_shape.numpy().tolist() if isinstance(
@@ -159,7 +171,7 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):
             use_cudnn)
 
     if not isinstance(theta, Variable):
-        raise ValueError("The theta should be a Variable.")
+        raise ValueError("The theta should be a Tensor.")
 
     out = helper.create_variable_for_type_inference(theta.dtype)
     ipts = {'Theta': theta}

From dd462092fa3add1b80993dc0b960e7ec564efdaa Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Mon, 24 Aug 2020 14:58:02 +0000
Subject: [PATCH 4/4] Add more unit tests for the affine_grid API

test=develop
---
 .../unittests/test_affine_grid_function.py | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_affine_grid_function.py b/python/paddle/fluid/tests/unittests/test_affine_grid_function.py
index ecef9a3456e94..5cfab78fda988 100644
--- a/python/paddle/fluid/tests/unittests/test_affine_grid_function.py
+++ b/python/paddle/fluid/tests/unittests/test_affine_grid_function.py
@@ -26,13 +26,17 @@ def __init__(self,
                  theta_shape=(20, 2, 3),
                  output_shape=[20, 2, 5, 7],
                  align_corners=True,
-                 dtype="float32"):
+                 dtype="float32",
+                 invalid_theta=False,
+                 variable_output_shape=False):
         super(AffineGridTestCase, self).__init__(methodName)
 
         self.theta_shape = theta_shape
         self.output_shape = output_shape
         self.align_corners = align_corners
         self.dtype = dtype
+        self.invalid_theta = invalid_theta
+        self.variable_output_shape = variable_output_shape
 
     def setUp(self):
         self.theta = np.random.randn(*(self.theta_shape)).astype(self.dtype)
@@ -70,9 +74,12 @@ def functional(self, place):
         return y_np
 
     def paddle_dygraph_layer(self):
-        theta_var = dg.to_variable(self.theta)
+        theta_var = dg.to_variable(
+            self.theta) if not self.invalid_theta else "invalid"
+        output_shape = dg.to_variable(
+            self.output_shape) if self.variable_output_shape else self.output_shape
         y_var = F.affine_grid(
-            theta_var, self.output_shape, align_corners=self.align_corners)
+            theta_var, output_shape, align_corners=self.align_corners)
         y_np = y_var.numpy()
         return y_np
 
@@ -108,6 +115,9 @@ def add_cases(suite):
     suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=True))
 
     suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=False))
+    suite.addTest(
+        AffineGridTestCase(
+            methodName='runTest', variable_output_shape=True))
 
     suite.addTest(
         AffineGridTestCase(
            methodName='runTest',
            theta_shape=(20, 2, 3),
            output_shape=[20, 1, 7, 7],
            align_corners=True))
@@ -121,6 +131,10 @@ def add_error_cases(suite):
     suite.addTest(
         AffineGridErrorTestCase(
             methodName='runTest', output_shape="not_valid"))
+    suite.addTest(
+        AffineGridErrorTestCase(
+            methodName='runTest',
+            invalid_theta=True))  # tests error checking when theta is not a variable
 
 
 def load_tests(loader, standard_tests, pattern):
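Taken together, the series makes paddle.nn.functional.affine_grid the public entry point for the new option. A minimal usage sketch under the 2.0 dygraph API shown in the docstrings above (illustration only; the 0.75 factor is the (size - 1) / size scaling with H = W = 4):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    # Identity affine transform for one sample: theta has shape [1, 2, 3].
    theta = paddle.to_tensor(np.eye(2, 3, dtype="float32")[np.newaxis, :])
    grid_true = F.affine_grid(theta, [1, 1, 4, 4], align_corners=True)
    grid_false = F.affine_grid(theta, [1, 1, 4, 4], align_corners=False)
    # With H == W == 4, every coordinate simply shrinks by (4 - 1) / 4.
    np.testing.assert_allclose(grid_false.numpy(), grid_true.numpy() * 0.75,
                               rtol=1e-5)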