From 671d2ff850df6cb4cc828fdcb1e10e3dac35bb41 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Mon, 10 Aug 2020 14:09:31 +0000
Subject: [PATCH 01/15] add pad func

---
 python/paddle/nn/functional/common.py | 195 +++++++++++++++++++++++++-
 1 file changed, 193 insertions(+), 2 deletions(-)

diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index fe41cb6e64c34..f83dd74a4cbf9 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -14,16 +14,17 @@
 import warnings
 from paddle.fluid.layer_helper import LayerHelper
-from paddle.fluid.layers.tensor import Variable, fill_constant
+from paddle.fluid.layers.tensor import Variable, fill_constant, zeros, concat
 # TODO: define the common functions to build a neural network
 from ...fluid.layers import dropout  #DEFINE_ALIAS
 from ...fluid.layers import label_smooth  #DEFINE_ALIAS
 from ...fluid import one_hot  #DEFINE_ALIAS
-from ...fluid.layers import pad  #DEFINE_ALIAS
 from ...fluid.layers import pad2d  #DEFINE_ALIAS
 from ...fluid.layers import unfold  #DEFINE_ALIAS
 from ...fluid.layers import assign  #DEFINE_ALIAS
+from ...fluid.layers import squeeze
+from ...fluid.layers import unsqueeze
 #from ...fluid.layers import fc  #DEFINE_ALIAS
 from ...fluid.layers import pad_constant_like  #DEFINE_ALIAS
@@ -446,3 +447,193 @@ def _is_list_or_turple_(data):
         outputs={"Out": out},
         attrs=attrs)
     return out
+
+
+def pad(input,
+        pad=[0, 0, 0, 0],
+        mode='constant',
+        value=0,
+        data_format="NCHW",
+        name=None):
+    """
+    :alias_main: paddle.nn.functional.pad
+    :alias: paddle.nn.functional.pad,paddle.nn.functional.common.pad
+    :old_api: paddle.fluid.layers.pad2d
+
+    Pad tensor according to 'pad' and 'mode'.
+    If mode is 'reflect', pad[0] and pad[1] must be no greater
+    than width-1. The height and depth dimensions have the same condition.
+
+    Parameters:
+        input (Variable): The input tensor with data type float32/float64/int32/int64.
+        pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
+            of input will be padded. 1. If the input dimension is 3, then the pad has the form (pad_left,
+            pad_right). 2. If the input dimension is 4, then the pad has the form (pad_left, pad_right,
+            pad_top, pad_bottom). 3. If the input dimension is 5, then the pad has the form
+            (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). Default is [0, 0, 0, 0].
+
+        mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'.
+            When in 'constant' mode, this op uses a constant value to pad the input tensor.
+            When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
+            When in 'replicate' mode, uses input boundaries to pad the input tensor.
+            When in 'circular' mode, uses circular input to pad the input tensor.
+            Default is 'constant'.
+        value (float32): The value to fill the padded areas in 'constant' mode. Default is 0.0.
+        data_format (str): A string from: "NCL", "NLC", "NHWC", "NCHW", "NCDHW", "NDHWC". Specify the data format of
+            the input data.
+            Default is "NCHW".
+        name (str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns: a Tensor padded according to pad and mode, with the same data type as the input.
+    Return Type: Variable
+
+    Examples:
+        .. code-block:: text
+
+            Input = [[[[1., 2., 3.],
+                       [4., 5., 6.]]]]
+
+            Case 0:
+                pad = [2, 2, 1, 1, 0, 0],
+                mode = 'constant'
+                pad_value = 0
+                Out = [[[[[0. 0. 0. 0. 0. 0. 0.]
+                          [0. 0. 1. 2. 3. 0. 0.]
+                          [0. 0. 4. 5. 6. 0. 0.]
+                          [0. 0. 0. 0. 0. 0. 0.]]]]]
+
+            Case 1:
+                pad = [2, 2, 1, 1, 0, 0],
+                mode = 'reflect'
+                Out = [[[[[6. 5. 4. 5. 6. 5. 4.]
+                          [3. 2. 1. 2. 3. 2. 1.]
+                          [6. 5. 4. 5. 6. 5. 4.]
+                          [3. 2. 1. 2. 3. 2. 1.]]]]]
+
+            Case 2:
+                pad = [2, 2, 1, 1, 0, 0],
+                mode = 'replicate'
+                Out = [[[[[1. 1. 1. 2. 3. 3. 3.]
+                          [1. 1. 1. 2. 3. 3. 3.]
+                          [4. 4. 4. 5. 6. 6. 6.]
+                          [4. 4. 4. 5. 6. 6. 6.]]]]]
+
+            Case 3:
+                pad = [2, 2, 1, 1, 0, 0],
+                mode = 'circular'
+                Out = [[[[[5. 6. 4. 5. 6. 4. 5.]
+                          [2. 3. 1. 2. 3. 1. 2.]
+                          [5. 6. 4. 5. 6. 4. 5.]
+                          [2. 3. 1. 2. 3. 1. 2.]]]]]
+
+    Code Examples:
+        .. code-block:: python
+
+            # declarative mode
+            import numpy as np
+            import paddle
+            import paddle.nn.functional as F
+
+            input_shape = (1, 1, 3)
+            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
+            x = paddle.data(name="x", shape=input_shape)
+            y = F.pad(x, pad=[5, 6], value=1, mode='constant')
+            place = paddle.CPUPlace()
+            exe = paddle.Executor(place)
+            outputs = exe.run(feed={'x': data}, fetch_list=[y.name])
+            print(outputs[0])
+            # [[[1. 1. 1. 2. 3. 1. 1. 1.]]]
+
+            # imperative mode
+            import paddle.fluid.dygraph as dg
+            input_shape = (1, 1, 2, 3)
+            pad = [1, 2, 1, 1]
+            input_data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
+            with dg.guard(place) as g:
+                input = dg.to_variable(input_data)
+                output = paddle.nn.functional.pad(input=input, pad=pad, mode="circular")
+                print(output.numpy())
+                # [[[[6. 4. 5. 6. 4. 5.]
+                #    [3. 1. 2. 3. 1. 2.]
+                #    [6. 4. 5. 6. 4. 5.]
+                #    [3. 1. 2. 3. 1. 2.]]]]
+
+    """
+    data_format = data_format.upper()
+
+    input_dim = len(input.shape)
+
+    original_data_format = data_format
+    unsqueezed_dim = []
+
+    if isinstance(pad, Variable):
+        if data_format in ["NCL", "NCHW", "NCDHW"]:
+            data_format = "NCDHW"
+            if input_dim == 3:
+                pad = concat([zeros((4, ), dtype="int32"), pad], axis=0)
+                unsqueezed_dim = [3, 4]
+                input = unsqueeze(input, axes=unsqueezed_dim)
+            elif input_dim == 4:
+                pad = concat([pad, zeros((2, ), dtype="int32")], axis=0)
+                unsqueezed_dim = [2]
+                input = unsqueeze(input, axes=unsqueezed_dim)
+        elif data_format in ["NLC", "NHWC", "NDHWC"]:
+            data_format = "NDHWC"
+            if input_dim == 3:
+                pad = concat([zeros((4, ), dtype="int32"), pad], axis=0)
+                unsqueezed_dim = [2, 3]
+                input = unsqueeze(input, axes=unsqueezed_dim)
+            elif input_dim == 4:
+                pad = concat([pad, zeros((2, ), dtype="int32")], axis=0)
+                unsqueezed_dim = [1]
+                input = unsqueeze(input, axes=unsqueezed_dim)
+    else:
+        if data_format in ["NCL", "NCHW", "NCDHW"]:
+            data_format = "NCDHW"
+            if input_dim == 3:
+                pad = [0, 0, 0, 0] + pad
+                unsqueezed_dim = [3, 4]
+                input = unsqueeze(input, axes=unsqueezed_dim)
+            elif input_dim == 4:
+                pad = pad + [0, 0]
+                unsqueezed_dim = [2]
+                input = unsqueeze(input, axes=unsqueezed_dim)
+        elif data_format in ["NLC", "NHWC", "NDHWC"]:
+            data_format = "NDHWC"
+            if input_dim == 3:
+                pad = [0, 0, 0, 0] + pad
+                unsqueezed_dim = [2, 3]
+                input = unsqueeze(input, axes=unsqueezed_dim)
+            elif input_dim == 4:
+                pad = pad + [0, 0]
+                unsqueezed_dim = [1]
+                input = unsqueeze(input, axes=unsqueezed_dim)
+        else:
+            raise ValueError("data_format should be in one of "
+                             "[NCL, NCHW, NCDHW, NLC, NHWC, NDHWC], "
+                             "but got {}".format(data_format))
+
+    attrs = {'mode': mode, 'value': value, 'data_format': data_format}
+
+    inputs = {'X': [input]}
+    if isinstance(pad, Variable):
+        inputs['Paddings'] = [pad]
+        attrs['paddings'] = []
+    else:
+        attrs['paddings'] = pad
+
+    helper = LayerHelper('pad3d', **locals())
+
+    assert mode in ['reflect', 'replicate', 'constant', 'circular'], \
+        "mode should be one of constant, reflect, replicate, circular, but got {}.".format(mode)
+
+    dtype = helper.input_dtype(input_param_name='input')
+    out = helper.create_variable_for_type_inference(dtype)
+
+    helper.append_op(
+        type='pad3d', inputs=inputs, outputs={"Out": out}, attrs=attrs)
+
+    if len(unsqueezed_dim) != 0:
+        out = squeeze(out, axes=unsqueezed_dim)
+
+    return out
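The four `mode` values correspond one-to-one to NumPy padding modes ('replicate' is NumPy's 'edge', 'circular' is NumPy's 'wrap'), which gives a quick way to cross-check the expected outputs in the docstring above. A minimal NumPy sketch, not part of the patch:

    import numpy as np

    x = np.array([[[1., 2., 3.]]])       # shape (1, 1, 3), i.e. "NCL"
    widths = [(0, 0), (0, 0), (2, 2)]    # pad only the last (width) axis

    print(np.pad(x, widths, mode="constant"))  # pad mode 'constant'
    print(np.pad(x, widths, mode="reflect"))   # pad mode 'reflect'
    print(np.pad(x, widths, mode="edge"))      # pad mode 'replicate'
    print(np.pad(x, widths, mode="wrap"))      # pad mode 'circular'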
From 53a713f379f81a126bbe00b28795b6772ecb9e16 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Tue, 11 Aug 2020 06:02:37 +0000
Subject: [PATCH 02/15] add pad

---
 python/paddle/nn/__init__.py          |   5 +-
 python/paddle/nn/functional/common.py |  16 +-
 python/paddle/nn/layer/__init__.py    |   5 +-
 python/paddle/nn/layer/common.py      | 310 +++++++++++++++++++++-----
 4 files changed, 275 insertions(+), 61 deletions(-)

diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 98948fa91e2e8..c41270c96503c 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -60,7 +60,10 @@
 from .layer.activation import HSigmoid  #DEFINE_ALIAS
 from .layer.common import BilinearTensorProduct  #DEFINE_ALIAS
 from .layer.common import Pool2D  #DEFINE_ALIAS
-from .layer.common import Pad2D  #DEFINE_ALIAS
+from .layer.common import Pad  #DEFINE_ALIAS
+from .layer.common import ReflectionPad1d  #DEFINE_ALIAS
+from .layer.common import ReplicationPad1d  #DEFINE_ALIAS
+from .layer.common import ConstantPad1d  #DEFINE_ALIAS
 from .layer.common import Embedding  #DEFINE_ALIAS
 from .layer.common import Linear  #DEFINE_ALIAS
 from .layer.common import Flatten  #DEFINE_ALIAS

diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index f83dd74a4cbf9..32a191b47d8e7 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -473,17 +473,17 @@ def pad(input,
             (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). Default is [0, 0, 0, 0].
 
         mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'.
-            When in 'constant' mode, this op uses a constant value to pad the input tensor.
-            When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
-            When in 'replicate' mode, uses input boundaries to pad the input tensor.
+            When in 'constant' mode, this op uses a constant value to pad the input tensor.
+            When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
+            When in 'replicate' mode, uses input boundaries to pad the input tensor.
             When in 'circular' mode, uses circular input to pad the input tensor.
-            Default is 'constant'.
+            Default is 'constant'.
         value (float32): The value to fill the padded areas in 'constant' mode. Default is 0.0.
         data_format (str): A string from: "NCL", "NLC", "NHWC", "NCHW", "NCDHW", "NDHWC". Specify the data format of
-            the input data.
-            Default is "NCHW".
+            the input data.
+            Default is "NCHW".
         name (str, optional): The default value is None. Normally there is no need for
-            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns: a Tensor padded according to pad and mode, with the same data type as the input.
@@ -537,7 +537,7 @@ def pad(input,
             input_shape = (1, 1, 3)
             data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
             x = paddle.data(name="x", shape=input_shape)
-            y = F.pad(x, pad=[5, 6], value=1, mode='constant')
+            y = F.pad(x, pad=[2, 3], value=1, mode='constant')
             place = paddle.CPUPlace()
             exe = paddle.Executor(place)
             outputs = exe.run(feed={'x': data}, fetch_list=[y.name])

diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 7173c5b587759..f09d655b07100 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -36,7 +36,10 @@
 from .activation import HSigmoid  #DEFINE_ALIAS
 from .common import BilinearTensorProduct  #DEFINE_ALIAS
 from .common import Pool2D  #DEFINE_ALIAS
-from .common import Pad2D  #DEFINE_ALIAS
+from .common import Pad  #DEFINE_ALIAS
+from .common import ReflectionPad1d  #DEFINE_ALIAS
+from .common import ReplicationPad1d  #DEFINE_ALIAS
+from .common import ConstantPad1d  #DEFINE_ALIAS
 from .common import Embedding  #DEFINE_ALIAS
 from .common import Linear  #DEFINE_ALIAS
 from .common import Flatten  #DEFINE_ALIAS

diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index 45259bea49d42..e98eafebbc498 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -22,8 +22,8 @@
 from .. import functional as F
 
 __all__ = [
-    'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample',
-    'Pad2D'
+    'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample', 'Pad',
+    'ReflectionPad1d', 'ReplicationPad1d', 'ConstantPad1d'
 ]
 
@@ -254,30 +254,34 @@ def forward(self, input):
         return out
 
-class Pad2D(layers.Layer):
+class Pad(layers.Layer):
     """
-    :alias_main: paddle.nn.Pad2D
-    :alias: paddle.nn.Pad2D,paddle.nn.layer.Pad2D,paddle.nn.layer.common.Pad2D
+    :alias_main: paddle.nn.Pad
+    :alias: paddle.nn.Pad
 
-    This interface is used to construct a callable object of the ``Pad2D`` class.
-    The Pad2D layer pads the input tensor boundaries according to 'paddings' and 'mode'.
-    If mode is 'reflect', paddings[0] and paddings[1] must be no greater
-    than height-1. And the width dimension has the same condition.
+    This interface is used to construct a callable object of the ``Pad`` class.
+    The Pad layer pads the input tensor boundaries according to 'pad' and 'mode'.
+    If mode is 'reflect', pad[0] and pad[1] must be no greater
+    than width-1. The height and depth dimensions have the same condition.
 
     Parameters:
-        paddings (int | List[int32]): The padding size. If padding is a int, uses the same
-            padding in all boundaries, if padding is a List, it must contain four integers,
-            (padding_top, padding_bottom, padding_left, padding_right).
-            Default is [0, 0, 0, 0].
-        mode (str): Three modes: 'constant' (default), 'reflect', 'edge' .
-            When in 'constant' mode, this op uses a constant value to pad the input tensor.
-            When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
-            When in 'edge' mode, uses input boundaries to pad the input tensor.
-            Default is 'constant'
-        pad_value (float32): The value to fill the padded areas in 'constant' mode . Default is 0.0
-        data_format (str): An string from: "NHWC", "NCHW". Specify the data format of
-            the input data.
-            Default is "NCHW"
+        pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
+            of input will be padded. 1. If the input dimension is 3, then the pad has the form (pad_left,
+            pad_right). 2. If the input dimension is 4, then the pad has the form (pad_left, pad_right,
+            pad_top, pad_bottom). 3. If the input dimension is 5, then the pad has the form
+            (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). Default is [0, 0, 0, 0].
+        mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'.
+            When in 'constant' mode, this op uses a constant value to pad the input tensor.
+            When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
+            When in 'replicate' mode, uses input boundaries to pad the input tensor.
+            When in 'circular' mode, uses circular input to pad the input tensor.
+            Default is 'constant'.
+        value (float32): The value to fill the padded areas in 'constant' mode. Default is 0.0.
+        data_format (str): A string from: "NCL", "NLC", "NHWC", "NCHW", "NCDHW", "NDHWC". Specify the data format of
+            the input data.
+            Default is "NCHW".
+        name (str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
         None
 
@@ -285,30 +289,41 @@ class Pad2D(layers.Layer):
     Examples:
         .. code-block:: text
 
-            Input = [[[[1., 2., 3.],
-                       [4., 5., 6.]]]]
+            Input = [[[[[1., 2., 3.],
+                        [4., 5., 6.]]]]]
 
             Case 0:
-                paddings = [0, 1, 2, 3],
+                pad = [2, 2, 1, 1, 0, 0],
                 mode = 'constant'
                 pad_value = 0
-                Out = [[[[0., 0., 1., 2., 3., 0., 0., 0.],
-                         [0., 0., 4., 5., 6., 0., 0., 0.],
-                         [0., 0., 0., 0., 0., 0., 0., 0.]]]]
+                Out = [[[[[0. 0. 0. 0. 0. 0. 0.]
+                          [0. 0. 1. 2. 3. 0. 0.]
+                          [0. 0. 4. 5. 6. 0. 0.]
+                          [0. 0. 0. 0. 0. 0. 0.]]]]]
 
             Case 1:
-                paddings = [0, 1, 2, 1],
+                pad = [2, 2, 1, 1, 0, 0],
                 mode = 'reflect'
-                Out = [[[[3., 2., 1., 2., 3., 2.],
-                         [6., 5., 4., 5., 6., 5.],
-                         [3., 2., 1., 2., 3., 2.]]]]
+                Out = [[[[[6. 5. 4. 5. 6. 5. 4.]
+                          [3. 2. 1. 2. 3. 2. 1.]
+                          [6. 5. 4. 5. 6. 5. 4.]
+                          [3. 2. 1. 2. 3. 2. 1.]]]]]
 
             Case 2:
-                paddings = [0, 1, 2, 1],
-                mode = 'edge'
-                Out = [[[[1., 1., 1., 2., 3., 3.],
-                         [4., 4., 4., 5., 6., 6.],
-                         [4., 4., 4., 5., 6., 6.]]]]
+                pad = [2, 2, 1, 1, 0, 0],
+                mode = 'replicate'
+                Out = [[[[[1. 1. 1. 2. 3. 3. 3.]
+                          [1. 1. 1. 2. 3. 3. 3.]
+                          [4. 4. 4. 5. 6. 6. 6.]
+                          [4. 4. 4. 5. 6. 6. 6.]]]]]
+
+            Case 3:
+                pad = [2, 2, 1, 1, 0, 0],
+                mode = 'circular'
+                Out = [[[[[5. 6. 4. 5. 6. 4. 5.]
+                          [2. 3. 1. 2. 3. 1. 2.]
+                          [5. 6. 4. 5. 6. 4. 5.]
+                          [2. 3. 1. 2. 3. 1. 2.]]]]]
 
     Code Examples:
         .. code-block:: python
 
@@ -316,29 +331,222 @@ class Pad2D(layers.Layer):
             import paddle.fluid as fluid
             import paddle.nn as nn
             import numpy as np
-            data = np.ones((2, 2, 2, 2)).astype('float32')
-            my_pad = nn.Pad2D(paddings=[1, 1, 1, 1])
+            data = np.ones((1, 1, 2, 2)).astype('float32')
+            my_pad = nn.Pad(pad=[1, 1, 1, 1])
             with fluid.dygraph.guard():
                 data = fluid.dygraph.to_variable(data)
                 result = my_pad(data)
+                print(result.numpy())
+                # [[[[0. 0. 0. 0.]
+                #    [0. 1. 1. 0.]
+                #    [0. 1. 1. 0.]
+                #    [0. 0. 0. 0.]]]]
     """
 
     def __init__(self,
-                 paddings=0,
+                 pad=[0, 0, 0, 0],
                  mode='constant',
-                 pad_value=0.0,
-                 data_format="NCHW"):
-        super(Pad2D, self).__init__()
+                 value=0.0,
+                 data_format="NCHW",
+                 name=None):
+        super(Pad, self).__init__()
         self._mode = mode
-        self._pad_value = pad_value
+        self._value = value
         self._data_format = data_format
-        self._paddings = [paddings] * 4 if isinstance(paddings,
-                                                      int) else paddings
+        self._pad = pad
+        self._name = name
 
     def forward(self, input):
-        return F.pad2d(
-            input,
-            paddings=self._paddings,
-            mode=self._mode,
-            pad_value=self._pad_value,
-            data_format=self._data_format)
+        return F.pad(input,
+                     pad=self._pad,
+                     mode=self._mode,
+                     value=self._value,
+                     data_format=self._data_format,
+                     name=self._name)
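Whatever the mode, the padded shape depends only on `pad`: each spatial dimension grows by the sum of its two pad entries, applied to the trailing dimensions from the innermost (width) outwards. A small hypothetical helper, not part of the patch, that mirrors this arithmetic:

    def padded_shape(shape, pad):
        # pad pairs apply to the trailing spatial dims, innermost (width) first
        out = list(shape)
        for i in range(len(pad) // 2):
            out[len(shape) - 1 - i] += pad[2 * i] + pad[2 * i + 1]
        return out

    print(padded_shape([1, 1, 2, 3], [1, 2, 1, 1]))           # [1, 1, 4, 6]
    print(padded_shape([2, 3, 4, 5, 6], [1, 1, 2, 2, 3, 3]))  # [2, 3, 10, 9, 8]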
+
+
+class ReflectionPad1d(layers.Layer):
+    """
+    :alias_main: paddle.nn.ReflectionPad1d
+    :alias: paddle.nn.ReflectionPad1d
+
+    This interface is used to construct a callable object of the ``ReflectionPad1d`` class.
+    Uses reflection of the input boundaries to pad the input tensor.
+
+    Parameters:
+        pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
+            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0].
+        data_format (str): A string from: "NCL", "NLC". Specify the data format of the input data.
+            Default is "NCL".
+        name (str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: text
+
+            Input = [[[1., 2., 3.],
+                      [4., 5., 6.]]]
+            pad = [1, 2],
+            Out = [[[2. 1. 2. 3. 2. 1.]
+                    [5. 4. 5. 6. 5. 4.]]]
+
+    Code Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import paddle.nn as nn
+            import numpy as np
+            input_shape = (1, 2, 3)
+            pad = [1, 2]
+            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
+            my_pad = nn.ReflectionPad1d(pad=pad)
+            with fluid.dygraph.guard():
+                data = fluid.dygraph.to_variable(data)
+                result = my_pad(data)
+                print(result.numpy())
+                # [[[2. 1. 2. 3. 2. 1.]
+                #   [5. 4. 5. 6. 5. 4.]]]
+    """
+
+    def __init__(self, pad=[0, 0], data_format="NCL", name=None):
+        super(ReflectionPad1d, self).__init__()
+        self._mode = "reflect"
+        self._data_format = data_format
+        self._pad = pad
+        self._name = name
+
+    def forward(self, input):
+        return F.pad(input,
+                     pad=self._pad,
+                     mode=self._mode,
+                     data_format=self._data_format,
+                     name=self._name)
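The example above can be reproduced with the 'reflect' index map that the pad3d kernels added in patch 03 below use: an out-of-range source coordinate is reflected first about 0 and then about n-1. A NumPy sketch with a hypothetical helper name:

    import numpy as np

    def reflect_index(i, n):
        i = max(i, -i)                  # reflect about 0
        return min(i, 2 * n - i - 2)    # reflect about n - 1

    x = np.array([1., 2., 3.])
    pad_left, pad_right = 1, 2
    out = [x[reflect_index(i - pad_left, len(x))]
           for i in range(len(x) + pad_left + pad_right)]
    print(out)  # [2.0, 1.0, 2.0, 3.0, 2.0, 1.0]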
+
+
+class ReplicationPad1d(layers.Layer):
+    """
+    :alias_main: paddle.nn.ReplicationPad1d
+    :alias: paddle.nn.ReplicationPad1d
+
+    This interface is used to construct a callable object of the ``ReplicationPad1d`` class.
+    Uses input boundaries to pad the input tensor.
+
+    Parameters:
+        pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
+            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0].
+        data_format (str): A string from: "NCL", "NLC". Specify the data format of the input data.
+            Default is "NCL".
+        name (str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: text
+
+            Input = [[[1., 2., 3.],
+                      [4., 5., 6.]]]
+            pad = [1, 2],
+            Out = [[[1. 1. 2. 3. 3. 3.]
+                    [4. 4. 5. 6. 6. 6.]]]
+
+    Code Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import paddle.nn as nn
+            import numpy as np
+            input_shape = (1, 2, 3)
+            pad = [1, 2]
+            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
+            my_pad = nn.ReplicationPad1d(pad=pad)
+            with fluid.dygraph.guard():
+                data = fluid.dygraph.to_variable(data)
+                result = my_pad(data)
+                print(result.numpy())
+                # [[[1. 1. 2. 3. 3. 3.]
+                #   [4. 4. 5. 6. 6. 6.]]]
+    """
+
+    def __init__(self, pad=[0, 0], data_format="NCL", name=None):
+        super(ReplicationPad1d, self).__init__()
+        self._mode = "replicate"
+        self._data_format = data_format
+        self._pad = pad
+        self._name = name
+
+    def forward(self, input):
+        return F.pad(input,
+                     pad=self._pad,
+                     mode=self._mode,
+                     data_format=self._data_format,
+                     name=self._name)
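For comparison, 'replicate' clamps the source coordinate to the valid range and 'circular' wraps it modulo the axis length; these are the same index rules the pad3d kernels use. Another hedged NumPy sketch, with hypothetical helper names:

    import numpy as np

    def replicate_index(i, n):
        return min(n - 1, max(i, 0))    # clamp to [0, n-1]

    def circular_index(i, n):
        return (i % n + n) % n          # wrap around

    x = np.array([1., 2., 3.])
    pad_left, pad_right = 1, 2
    idx = [i - pad_left for i in range(len(x) + pad_left + pad_right)]
    print([x[replicate_index(i, len(x))] for i in idx])  # [1.0, 1.0, 2.0, 3.0, 3.0, 3.0]
    print([x[circular_index(i, len(x))] for i in idx])   # [3.0, 1.0, 2.0, 3.0, 1.0, 2.0]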
+
+
+class ConstantPad1d(layers.Layer):
+    """
+    :alias_main: paddle.nn.ConstantPad1d
+    :alias: paddle.nn.ConstantPad1d
+
+    This interface is used to construct a callable object of the ``ConstantPad1d`` class.
+    Uses a constant value to pad the input tensor.
+
+    Parameters:
+        pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
+            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0].
+        value (float32): The value to fill the padded areas. Default is 0.0.
+        data_format (str): A string from: "NCL", "NLC". Specify the data format of the input data.
+            Default is "NCL".
+        name (str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: text
+
+            Input = [[[1., 2., 3.],
+                      [4., 5., 6.]]]
+            pad = [1, 2],
+            value = 0.0
+            Out = [[[0. 1. 2. 3. 0. 0.]
+                    [0. 4. 5. 6. 0. 0.]]]
+
+    Code Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import paddle.nn as nn
+            import numpy as np
+            input_shape = (1, 2, 3)
+            pad = [1, 2]
+            value = 0.0
+            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
+            my_pad = nn.ConstantPad1d(pad=pad, value=value)
+            with fluid.dygraph.guard():
+                data = fluid.dygraph.to_variable(data)
+                result = my_pad(data)
+                print(result.numpy())
+                # [[[0. 1. 2. 3. 0. 0.]
+                #   [0. 4. 5. 6. 0. 0.]]]
+    """
+
+    def __init__(self, pad=[0, 0], value=0.0, data_format="NCL", name=None):
+        super(ConstantPad1d, self).__init__()
+        self._mode = "constant"
+        self._data_format = data_format
+        self._pad = pad
+        self._value = value
+        self._name = name
+
+    def forward(self, input):
+        return F.pad(input,
+                     pad=self._pad,
+                     mode=self._mode,
+                     value=self._value,
+                     data_format=self._data_format,
+                     name=self._name)
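Patch 03 below adds the pad3d operator itself together with its gradient kernels. For the non-constant modes the index map is many-to-one, so the backward pass is a scatter-add: each padded output cell adds its upstream gradient into the input cell it was read from. A NumPy sketch of that accumulation for 'reflect' (hypothetical helper names, not part of the patch):

    import numpy as np

    def reflect_index(i, n):
        i = max(i, -i)
        return min(i, 2 * n - i - 2)

    n, pad_left, pad_right = 3, 1, 2
    d_out = np.ones(n + pad_left + pad_right)  # upstream gradient of all ones
    d_in = np.zeros(n)
    for i in range(len(d_out)):
        d_in[reflect_index(i - pad_left, n)] += d_out[i]
    print(d_in)  # [2. 3. 1.] -- interior cells are hit by several output cells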
From 2c09d75fc8e0a6aac4bf0b280ddd252b54827366 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Tue, 11 Aug 2020 06:12:06 +0000
Subject: [PATCH 03/15] test=develop, add pad op and apis

---
 paddle/fluid/operators/pad3d_op.cc             | 1011 ++++++++++++++++
 paddle/fluid/operators/pad3d_op.cu             |  793 ++++++++++++
 .../fluid/tests/unittests/test_pad3d_op.py     |  387 ++++++
 3 files changed, 2191 insertions(+)
 create mode 100644 paddle/fluid/operators/pad3d_op.cc
 create mode 100644 paddle/fluid/operators/pad3d_op.cu
 create mode 100644 python/paddle/fluid/tests/unittests/test_pad3d_op.py

diff --git a/paddle/fluid/operators/pad3d_op.cc b/paddle/fluid/operators/pad3d_op.cc
new file mode 100644
index 0000000000000..73d925fd6ed9a
--- /dev/null
+++ b/paddle/fluid/operators/pad3d_op.cc
@@ -0,0 +1,1011 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +void Pad3DConstNCDHW(const T* in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, const int pad_left, + T value, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + out_data[out_d * out_height * out_width + out_h * out_width + + out_w] = + (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width) + ? value + : in_data[in_d * in_height * in_width + in_h * in_width + + in_w]; + } + } + } + in_data += in_depth * in_height * in_width; + out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DConstNDHWC(const T* in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, const int pad_left, + T value, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * + channels; + if (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width) { + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = value; + } + } else { + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * + channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } + } + } + } + } + in_data += in_depth * in_height * in_width * channels; + out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3DReflectNCDHW(const T* in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - 
pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); // reflect by 0 + in_d = + std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth + in_h = std::max(in_h, -in_h); // reflect by 0 + in_h = std::min(in_h, + 2 * in_height - in_h - 2); // reflect by in_height + in_w = std::max(in_w, -in_w); // reflect by 0 + in_w = + std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + + out_data[out_d * out_height * out_width + out_h * out_width + + out_w] = + in_data[in_d * in_height * in_width + in_h * in_width + in_w]; + } + } + } + in_data += in_depth * in_height * in_width; + out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DReflectNDHWC(const T* in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); + in_d = std::min(in_d, 2 * in_depth - in_d - 2); + in_h = std::max(in_h, -in_h); + in_h = std::min(in_h, 2 * in_height - in_h - 2); + in_w = std::max(in_w, -in_w); + in_w = std::min(in_w, 2 * in_width - in_w - 2); + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * + channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } + } + } + } + in_data += in_depth * in_height * in_width * channels; + out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3DReplicateNCDHW(const T* in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + + out_data[out_d * out_height * out_width + out_h * out_width + + out_w] = + in_data[in_d * in_height * in_width + in_h * in_width + in_w]; + } + } + } + in_data += in_depth * in_height * in_width; + out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DReplicateNDHWC(const T* in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 
0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * + channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } + } + } + } + in_data += in_depth * in_height * in_width * channels; + out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3DCircularNCDHW(const T* in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + out_data[out_d * out_height * out_width + out_h * out_width + + out_w] = + in_data[in_d * in_height * in_width + in_h * in_width + in_w]; + } + } + } + in_data += in_depth * in_height * in_width; + out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DCircularNDHWC(const T* in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * + channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } + } + } + } + in_data += in_depth * in_height * in_width * channels; + out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3DGradConstNCDHW(T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width)) { + d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] = + d_out_data[out_d * out_height * out_width + + out_h * out_width + out_w]; + } + } + } + } + d_in_data += in_depth * in_height * 
in_width; + d_out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DGradConstNDHWC(T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * + channels; + if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width)) { + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * + channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] = d_out_data[out_index + c]; + } + } + } + } + } + d_in_data += in_depth * in_height * in_width * channels; + d_out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3DGradReflectNCDHW(T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); // reflect by 0 + in_d = + std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth + in_h = std::max(in_h, -in_h); // reflect by 0 + in_h = std::min(in_h, + 2 * in_height - in_h - 2); // reflect by in_height + in_w = std::max(in_w, -in_w); // reflect by 0 + in_w = + std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + + d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += + d_out_data[out_d * out_height * out_width + out_h * out_width + + out_w]; + } + } + } + d_in_data += in_depth * in_height * in_width; + d_out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DGradReflectNDHWC(T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); + in_d = std::min(in_d, 2 * in_depth - in_d - 2); + in_h = std::max(in_h, -in_h); + in_h = std::min(in_h, 2 * in_height - in_h - 2); + in_w = std::max(in_w, -in_w); + in_w = std::min(in_w, 2 * in_width - in_w - 2); + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * + channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) 
{ + d_in_data[in_index + c] += d_out_data[out_index + c]; + } + } + } + } + d_in_data += in_depth * in_height * in_width * channels; + d_out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3DGradReplicateNCDHW(T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + + d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += + d_out_data[out_d * out_height * out_width + out_h * out_width + + out_w]; + } + } + } + d_in_data += in_depth * in_height * in_width; + d_out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DGradReplicateNDHWC(T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0)); + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * + channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] += d_out_data[out_index + c]; + } + } + } + } + d_in_data += in_depth * in_height * in_width * channels; + d_out_data += out_depth * out_height * out_width * channels; + } +} + +template +void Pad3DGradCircularNCDHW(T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] += + d_out_data[out_d * out_height * out_width + out_h * out_width + + out_w]; + } + } + } + d_in_data += in_depth * in_height * in_width; + d_out_data += out_depth * out_height * out_width; + } + } +} + +template +void Pad3DGradCircularNDHWC(T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const 
int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int out_d = 0; out_d < out_depth; ++out_d) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * + channels; + const int in_index = + (in_d * in_height * in_width + in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] += d_out_data[out_index + c]; + } + } + } + } + d_in_data += in_depth * in_height * in_width * channels; + d_out_data += out_depth * out_height * out_width * channels; + } +} + +static inline void GetPaddings(int* paddings, + const framework::ExecutionContext& context) { + auto* paddings_t = context.Input("Paddings"); + if (paddings_t) { + auto paddings_data = paddings_t->data(); + paddings[0] = paddings_data[0]; + paddings[1] = paddings_data[1]; + paddings[2] = paddings_data[2]; + paddings[3] = paddings_data[3]; + paddings[4] = paddings_data[4]; + paddings[5] = paddings_data[5]; + } else { + auto pads = context.Attr>("paddings"); + std::copy(pads.begin(), pads.end(), paddings); + } +} + +template +class Pad3dCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + int pads[6]; + GetPaddings(pads, context); + auto mode = context.Attr("mode"); + auto data_format = context.Attr("data_format"); + T value = static_cast(context.Attr("value")); + + auto* x = context.Input("X"); + auto in_dims = x->dims(); + const T* in_data = x->data(); + + auto* out = context.Output("Out"); + if (data_format == "NCDHW") { + out->Resize({in_dims[0], in_dims[1], in_dims[2] + pads[4] + pads[5], + in_dims[3] + pads[2] + pads[3], + in_dims[4] + pads[0] + pads[1]}); + } else { + out->Resize({in_dims[0], in_dims[1] + pads[4] + pads[5], + in_dims[2] + pads[2] + pads[3], + in_dims[3] + pads[0] + pads[1], in_dims[4]}); + } + auto out_dims = out->dims(); + T* out_data = out->mutable_data(context.GetPlace()); + + int channels = in_dims[1]; + int in_depth = in_dims[2]; + int in_height = in_dims[3]; + int in_width = in_dims[4]; + int out_depth = out_dims[2]; + int out_height = out_dims[3]; + int out_width = out_dims[4]; + if (data_format == "NDHWC") { + channels = in_dims[4]; + in_depth = in_dims[1]; + in_height = in_dims[2]; + in_width = in_dims[3]; + out_depth = out_dims[1]; + out_height = out_dims[2]; + out_width = out_dims[3]; + } + + if (mode == "reflect") { + PADDLE_ENFORCE_GT(in_depth, pads[4], + platform::errors::InvalidArgument( + "The depth of Input(X)'s dimension should be " + "greater than pad_front" + " in reflect mode" + ", but received depth(%d) and pad_front(%d).", + in_depth, pads[4])); + PADDLE_ENFORCE_GT(in_depth, pads[5], + platform::errors::InvalidArgument( + "The depth of Input(X)'s dimension should be " + "greater than pad_back" + " in reflect mode" + ", but received depth(%d) and pad_back(%d).", + in_depth, pads[5])); + + PADDLE_ENFORCE_GT(in_height, pads[2], + platform::errors::InvalidArgument( + "The height of Input(X)'s dimension should be " + "greater than pad_top" + " in reflect mode" + ", but received depth(%d) and pad_top(%d).", + in_height, pads[2])); + 
PADDLE_ENFORCE_GT(in_height, pads[3], + platform::errors::InvalidArgument( + "The height of Input(X)'s dimension should be " + "greater than pad_bottom" + " in reflect mode" + ", but received depth(%d) and pad_bottom(%d).", + in_height, pads[3])); + + PADDLE_ENFORCE_GT(in_width, pads[0], + platform::errors::InvalidArgument( + "The width of Input(X)'s dimension should be " + "greater than pad_left" + " in reflect mode" + ", but received depth(%d) and pad_left(%d).", + in_width, pads[0])); + PADDLE_ENFORCE_GT(in_width, pads[1], + platform::errors::InvalidArgument( + "The width of Input(X)'s dimension should be " + "greater than pad_right" + " in reflect mode" + ", but received depth(%d) and pad_right(%d).", + in_width, pads[1])); + } + + const int pad_left = pads[0]; + const int pad_top = pads[2]; + const int pad_front = pads[4]; + const int num = in_dims[0]; + if (data_format == "NCDHW") { + if (mode == "reflect") { + Pad3DReflectNCDHW(in_data, num, channels, in_depth, in_height, in_width, + out_depth, out_height, out_width, pad_front, pad_top, + pad_left, out_data); + } else if (mode == "replicate") { + Pad3DReplicateNCDHW(in_data, num, channels, in_depth, in_height, + in_width, out_depth, out_height, out_width, + pad_front, pad_top, pad_left, out_data); + } else if (mode == "circular") { + Pad3DCircularNCDHW(in_data, num, channels, in_depth, in_height, + in_width, out_depth, out_height, out_width, + pad_front, pad_top, pad_left, out_data); + } else if (mode == "constant") { + Pad3DConstNCDHW(in_data, num, channels, in_depth, in_height, in_width, + out_depth, out_height, out_width, pad_front, pad_top, + pad_left, value, out_data); + } + } else { + if (mode == "reflect") { + Pad3DReflectNDHWC(in_data, num, channels, in_depth, in_height, in_width, + out_depth, out_height, out_width, pad_front, pad_top, + pad_left, out_data); + } else if (mode == "replicate") { + Pad3DReplicateNDHWC(in_data, num, channels, in_depth, in_height, + in_width, out_depth, out_height, out_width, + pad_front, pad_top, pad_left, out_data); + } else if (mode == "circular") { + Pad3DCircularNDHWC(in_data, num, channels, in_depth, in_height, + in_width, out_depth, out_height, out_width, + pad_front, pad_top, pad_left, out_data); + } else { + Pad3DConstNDHWC(in_data, num, channels, in_depth, in_height, in_width, + out_depth, out_height, out_width, pad_front, pad_top, + pad_left, value, out_data); + } + } + } +}; + +template +class Pad3dGradCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + int pads[6]; + GetPaddings(pads, context); + auto mode = context.Attr("mode"); + auto data_format = context.Attr("data_format"); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_in = context.Output(framework::GradVarName("X")); + auto d_in_dims = d_in->dims(); + auto d_out_dims = d_out->dims(); + const T* d_out_data = d_out->data(); + T* d_in_data = d_in->mutable_data(context.GetPlace()); + math::SetConstant set_zero; + set_zero(context.template device_context(), + d_in, static_cast(0)); + const int pad_left = pads[0]; + const int pad_top = pads[2]; + const int pad_front = pads[4]; + const int num = d_in_dims[0]; + if (data_format == "NCDHW") { + const int channels = d_in_dims[1]; + const int in_depth = d_in_dims[2]; + const int in_height = d_in_dims[3]; + const int in_width = d_in_dims[4]; + const int out_depth = d_out_dims[2]; + const int out_height = d_out_dims[3]; + const int out_width = d_out_dims[4]; + if (mode == "reflect") { + 
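+        // Reflect, replicate and circular padding map several output elements
+        // onto the same input element, so the grad kernels for those modes
+        // accumulate with += into d_in_data (zero-filled by set_zero above),
+        // while the constant-mode grad kernel can assign directly.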
+        Pad3DGradReflectNCDHW(d_in_data, num, channels, in_depth, in_height,
+                              in_width, out_depth, out_height, out_width,
+                              pad_front, pad_top, pad_left, d_out_data);
+      } else if (mode == "replicate") {
+        Pad3DGradReplicateNCDHW(d_in_data, num, channels, in_depth, in_height,
+                                in_width, out_depth, out_height, out_width,
+                                pad_front, pad_top, pad_left, d_out_data);
+      } else if (mode == "circular") {
+        Pad3DGradCircularNCDHW(d_in_data, num, channels, in_depth, in_height,
+                               in_width, out_depth, out_height, out_width,
+                               pad_front, pad_top, pad_left, d_out_data);
+      } else {
+        Pad3DGradConstNCDHW(d_in_data, num, channels, in_depth, in_height,
+                            in_width, out_depth, out_height, out_width,
+                            pad_front, pad_top, pad_left, d_out_data);
+      }
+    } else {
+      const int channels = d_in_dims[4];
+      const int in_depth = d_in_dims[1];
+      const int in_height = d_in_dims[2];
+      const int in_width = d_in_dims[3];
+      const int out_depth = d_out_dims[1];
+      const int out_height = d_out_dims[2];
+      const int out_width = d_out_dims[3];
+      if (mode == "reflect") {
+        Pad3DGradReflectNDHWC(d_in_data, num, channels, in_depth, in_height,
+                              in_width, out_depth, out_height, out_width,
+                              pad_front, pad_top, pad_left, d_out_data);
+      } else if (mode == "replicate") {
+        Pad3DGradReplicateNDHWC(d_in_data, num, channels, in_depth, in_height,
+                                in_width, out_depth, out_height, out_width,
+                                pad_front, pad_top, pad_left, d_out_data);
+      } else if (mode == "circular") {
+        Pad3DGradCircularNDHWC(d_in_data, num, channels, in_depth, in_height,
+                               in_width, out_depth, out_height, out_width,
+                               pad_front, pad_top, pad_left, d_out_data);
+      } else {
+        Pad3DGradConstNDHWC(d_in_data, num, channels, in_depth, in_height,
+                            in_width, out_depth, out_height, out_width,
+                            pad_front, pad_top, pad_left, d_out_data);
+      }
+    }
+  }
+};
+
+class Pad3dOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad3d");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad3d");
+
+    auto x_dim = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_EQ(x_dim.size(), 5,
+                      platform::errors::InvalidArgument(
+                          "The size of Input(X)'s dimension should be equal to "
+                          "5, but received %d. ",
+                          x_dim.size()));
+
+    std::vector<int64_t> out_dims(x_dim.size());
+    auto data_format = ctx->Attrs().Get<std::string>("data_format");
+    out_dims[0] = x_dim[0];
+    if (ctx->HasInput("Paddings")) {
+      auto paddings_dim = ctx->GetInputDim("Paddings");
+      PADDLE_ENFORCE_EQ(paddings_dim.size(), 1,
+                        platform::errors::InvalidArgument(
+                            "Size of Input(Paddings)'s dimension should be "
+                            "equal to 1, but received %d.",
+                            paddings_dim.size()));
+      if (ctx->IsRuntime()) {
+        PADDLE_ENFORCE_EQ(paddings_dim[0], 6,
+                          platform::errors::InvalidArgument(
+                              "Shape of Input(Paddings) should be equal to "
+                              "[6], but received [%d].",
+                              paddings_dim[0]));
+      }
+      out_dims[1] = x_dim[1];
+      out_dims[2] = x_dim[2];
+      out_dims[3] = x_dim[3];
+    } else {
+      auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+      PADDLE_ENFORCE_EQ(
+          paddings.size(), 6,
+          platform::errors::InvalidArgument(
+              "Size of paddings should be equal to 6, but received %d.",
+              static_cast<int>(paddings.size())));
+      if (data_format == "NCDHW") {
+        out_dims[1] = x_dim[1];  // channel
+        out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
+                          ? x_dim[2]
+                          : (x_dim[2] + paddings[4] + paddings[5]);  // depth
+
+        out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
+                          ? x_dim[3]
+                          : (x_dim[3] + paddings[2] + paddings[3]);  // height
+
+        out_dims[4] = ((!ctx->IsRuntime()) && (x_dim[4] < 0))
+                          ? x_dim[4]
+                          : (x_dim[4] + paddings[0] + paddings[1]);  // width
+      } else {  // NDHWC
+        out_dims[4] = x_dim[4];  // channel
+
+        out_dims[1] = ((!ctx->IsRuntime()) && (x_dim[1] < 0))
+                          ? x_dim[1]
+                          : (x_dim[1] + paddings[4] + paddings[5]);  // depth
+        out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
+                          ? x_dim[2]
+                          : (x_dim[2] + paddings[2] + paddings[3]);  // height
+        out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
+                          ? x_dim[3]
+                          : (x_dim[3] + paddings[0] + paddings[1]);  // width
+      }
+    }
+
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class Pad3dOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input of pad3d op. "
+             "The input should be a 5-D tensor with format NCDHW or NDHWC.");
+    AddOutput("Out",
+              "The output of pad3d op. "
+              "A tensor with the same shape as X.");
+    AddInput("Paddings",
+             "A 1-D tensor to describe the padding rules. "
+             "paddings=[0, 1, 2, 3, 4, 5] means "
+             "padding 0 column to left, 1 column to right, "
+             "2 rows to top, 3 rows to bottom, 4 depths to front "
+             "and 5 depths to back. Size of paddings must be 6.")
+        .AsDispensable();
+    AddAttr<std::vector<int>>(
+        "paddings",
+        "(vector<int>) "
+        "A list to describe the padding rules. "
+        "paddings=[0, 1, 2, 3, 4, 5] means "
+        "padding 0 column to left, 1 column to right, "
+        "2 rows to top, 3 rows to bottom, 4 depths to front "
+        "and 5 depths to back. Size of paddings must be 6.");
+    AddAttr<float>("value",
+                   "(float, default 0.0) "
+                   "The value to fill the padded areas in constant mode.")
+        .SetDefault(0.0f);
+    AddAttr<std::string>(
+        "mode",
+        "(string, default constant) "
+        "Four modes: constant(default), reflect, replicate, circular.")
+        .SetDefault("constant");
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default NCDHW) "
+        "An optional string from: \"NDHWC\", \"NCDHW\". "
+        "Defaults to \"NCDHW\". Specify the data format of the input data.")
+        .SetDefault("NCDHW");
+    AddComment(R"DOC(
+Pad3d Operator.
+Pad 3-d images according to 'paddings' and 'mode'.
+If mode is 'reflect', paddings[0] and paddings[1] must be no greater
+than width-1. The height and depth dimensions have the same condition.
+
+Given that X is a channel of image from input:
+
+X = [[[[[1, 2, 3],
+        [4, 5, 6]]]]]
+
+Case 0:
+
+paddings = [2, 2, 1, 1, 0, 0],
+mode = 'constant'
+pad_value = 0
+
+Out = [[[[[0. 0. 0. 0. 0. 0. 0.]
+          [0. 0. 1. 2. 3. 0. 0.]
+          [0. 0. 4. 5. 6. 0. 0.]
+          [0. 0. 0. 0. 0. 0. 0.]]]]]
+
+Case 1:
+
+paddings = [2, 2, 1, 1, 0, 0],
+mode = 'reflect'
+
+Out = [[[[[6. 5. 4. 5. 6. 5. 4.]
+          [3. 2. 1. 2. 3. 2. 1.]
+          [6. 5. 4. 5. 6. 5. 4.]
+          [3. 2. 1. 2. 3. 2. 1.]]]]]
+
+Case 2:
+
+paddings = [2, 2, 1, 1, 0, 0],
+mode = 'replicate'
+
+Out = [[[[[1. 1. 1. 2. 3. 3. 3.]
+          [1. 1. 1. 2. 3. 3. 3.]
+          [4. 4. 4. 5. 6. 6. 6.]
+          [4. 4. 4. 5. 6. 6. 6.]]]]]
+
+Case 3:
+
+paddings = [2, 2, 1, 1, 0, 0],
+mode = 'circular'
+
+Out = [[[[[5. 6. 4. 5. 6. 4. 5.]
+          [2. 3. 1. 2. 3. 1. 2.]
+          [5. 6. 4. 5. 6. 4. 5.]
+          [2. 3. 1. 2. 3. 1. 2.]]]]]
+
+)DOC");
+  }
+};
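+
+// Shape rule implemented by Pad3dOp::InferShape above, with
+// paddings = [left, right, top, bottom, front, back]:
+//   NCDHW: Out = {N, C, D + front + back, H + top + bottom, W + left + right}
+//   NDHWC: Out = {N, D + front + back, H + top + bottom, W + left + right, C}
+// For example, X = {2, 3, 4, 5, 6} in NCDHW with paddings = {1, 1, 2, 2, 3, 3}
+// gives Out = {2, 3, 10, 9, 8}.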
+
+class Pad3dOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad3d@Grad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "Pad3d@Grad");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class Pad3dOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> bind) const override {
+    bind->SetInput("X", this->Input("X"));
+    if (this->HasInput("Paddings")) {
+      bind->SetInput("Paddings", this->Input("Paddings"));
+    }
+    bind->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    bind->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    bind->SetAttrMap(this->Attrs());
+    bind->SetType("pad3d_grad");
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad3dOpGradNoNeedBufferVarsInferer, "X");
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(pad3d, ops::Pad3dOp, ops::Pad3dOpMaker,
+                  ops::Pad3dOpGradMaker<paddle::framework::OpDesc>,
+                  ops::Pad3dOpGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(pad3d_grad, ops::Pad3dOpGrad,
+                  ops::Pad3dOpGradNoNeedBufferVarsInferer);
+REGISTER_OP_CPU_KERNEL(pad3d, ops::Pad3dCPUKernel<float>,
+                       ops::Pad3dCPUKernel<double>, ops::Pad3dCPUKernel<int>,
+                       ops::Pad3dCPUKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL(pad3d_grad, ops::Pad3dGradCPUKernel<float>,
+                       ops::Pad3dGradCPUKernel<double>);
diff --git a/paddle/fluid/operators/pad3d_op.cu b/paddle/fluid/operators/pad3d_op.cu
new file mode 100644
index 0000000000000..3ab811f4905a0
--- /dev/null
+++ b/paddle/fluid/operators/pad3d_op.cu
@@ -0,0 +1,793 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/paddle/fluid/operators/pad3d_op.cu b/paddle/fluid/operators/pad3d_op.cu
new file mode 100644
index 0000000000000..3ab811f4905a0
--- /dev/null
+++ b/paddle/fluid/operators/pad3d_op.cu
@@ -0,0 +1,793 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+namespace paddle {
+namespace operators {
+
+using platform::PADDLE_CUDA_NUM_THREADS;
+
+using framework::Tensor;
+
+template <typename T>
+__global__ void Pad3DConstNCDHW(const int nthreads, const T* in_data,
+                                const int num, const int channels,
+                                const int in_depth, const int in_height,
+                                const int in_width, const int out_depth,
+                                const int out_height, const int out_width,
+                                const int pad_front, const int pad_top,
+                                const int pad_left, T value, T* out_data) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    int nc = index / out_width;
+
+    const int out_w = index % out_width;
+    const int out_h = nc % out_height;
+    nc /= out_height;
+    const int out_d = nc % out_depth;
+    nc /= out_depth;
+
+    int in_d = out_d - pad_front;
+    int in_h = out_h - pad_top;
+    int in_w = out_w - pad_left;
+    out_data[index] =
+        (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth ||
+         in_h >= in_height || in_w >= in_width)
+            ? value
+            : in_data[nc * in_depth * in_height * in_width +
+                      in_d * in_height * in_width + in_h * in_width + in_w];
+  }
+}
+
+template <typename T>
+__global__ void Pad3DConstNDHWC(const int nthreads, const T* in_data,
+                                const int num, const int channels,
+                                const int in_depth, const int in_height,
+                                const int in_width, const int out_depth,
+                                const int out_height, const int out_width,
+                                const int pad_front, const int pad_top,
+                                const int pad_left, T value, T* out_data) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    int n = index / channels;
+    const int c = index % channels;
+    const int out_w = n % out_width;
+    n /= out_width;
+    const int out_h = n % out_height;
+    n /= out_height;
+    const int out_d = n % out_depth;
+    n /= out_depth;
+    const int in_d = out_d - pad_front;
+    const int in_h = out_h - pad_top;
+    const int in_w = out_w - pad_left;
+
+    out_data[index] =
+        (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth ||
+         in_h >= in_height || in_w >= in_width)
+            ? 
value + : in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c]; + } +} + +template +__global__ void Pad3DReflectNCDHW(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = max(in_d, -in_d); // reflect by 0 + in_d = min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth + in_h = max(in_h, -in_h); // reflect by 0 + in_h = min(in_h, 2 * in_height - in_h - 2); // reflect by in_height + in_w = max(in_w, -in_w); // reflect by 0 + in_w = min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + out_data[index] = + in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * + in_width + + in_w]; + } +} + +template +__global__ void Pad3DReflectNDHWC(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = max(in_d, -in_d); + in_d = min(in_d, 2 * in_depth - in_d - 2); + in_h = max(in_h, -in_h); + in_h = min(in_h, 2 * in_height - in_h - 2); + in_w = max(in_w, -in_w); + in_w = min(in_w, 2 * in_width - in_w - 2); + + out_data[index] = in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c]; + } +} + +template +__global__ void Pad3DReplicateNCDHW(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); + int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + + out_data[index] = + in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * + in_width + + in_w]; + } +} + +template +__global__ void Pad3DReplicateNDHWC(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + 
CUDA_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + + int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); + int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + + out_data[index] = in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c]; + } +} + +template +__global__ void Pad3DCircularNCDHW(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + out_data[index] = + in_data[(nc * in_depth * in_height + in_d * in_height + in_h) * + in_width + + in_w]; + } +} + +template +__global__ void Pad3DCircularNDHWC(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, T* out_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + out_data[index] = in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c]; + } +} + +template +__global__ void Pad3DGradConstNCDHW(const int in_size, T* d_in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + CUDA_KERNEL_LOOP(in_index, in_size) { + const int in_w = in_index % in_width; + + int nc = in_index / in_width; + const int in_h = nc % in_height; + + nc /= in_height; + const int in_d = nc % in_depth; + + nc /= in_depth; + + const int out_d = in_d + pad_front; + const int out_h = in_h + pad_top; + const int out_w = in_w + pad_left; + d_in_data[in_index] = + d_out_data[nc * out_depth * out_height * out_width + + out_d * out_height * out_width + out_h * out_width + out_w]; + } +} + +template +__global__ void Pad3DGradConstNDHWC(const int in_size, T* d_in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int 
out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + CUDA_KERNEL_LOOP(in_index, in_size) { + const int c = in_index % channels; + int n = in_index / channels; + + const int in_w = n % in_width; + n /= in_width; + + const int in_h = n % in_height; + n /= in_height; + + const int in_d = n % in_depth; + n /= in_depth; + + const int out_d = in_d + pad_front; + const int out_h = in_h + pad_top; + const int out_w = in_w + pad_left; + + d_in_data[in_index] = + d_out_data[n * out_depth * out_height * out_width * channels + + out_d * out_height * out_width * channels + + out_h * out_width * channels + out_w * channels + c]; + } +} + +template +__global__ void Pad3DGradReflectNCDHW(const int out_size, T* d_in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + int nc = out_index / out_width; + const int out_w = out_index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = max(in_d, -in_d); + in_h = max(in_h, -in_h); + in_w = max(in_w, -in_w); + + in_d = min(in_d, 2 * in_depth - in_d - 2); + in_h = min(in_h, 2 * in_height - in_h - 2); + in_w = min(in_w, 2 * in_width - in_w - 2); + + platform::CudaAtomicAdd( + &d_in_data[nc * in_depth * in_height * in_width + + in_d * in_height * in_width + in_h * in_width + in_w], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradReflectNDHWC(const int out_size, T* d_in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + const int c = out_index % channels; + int n = out_index / channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = max(in_d, -in_d); + in_h = max(in_h, -in_h); + in_w = max(in_w, -in_w); + + in_d = min(in_d, in_depth * 2 - in_d - 2); + in_h = min(in_h, in_height * 2 - in_h - 2); + in_w = min(in_w, in_width * 2 - in_w - 2); + platform::CudaAtomicAdd( + &d_in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradReplicateNCDHW( + const int out_size, T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, const int in_width, + const int out_depth, const int out_height, const int out_width, + const int pad_front, const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + int nc = out_index / out_width; + const int out_w = out_index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + const int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); + const int in_h = min(in_height 
- 1, max(out_h - pad_top, 0)); + const int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + + platform::CudaAtomicAdd( + &d_in_data[nc * in_depth * in_height * in_width + + in_d * in_height * in_width + in_h * in_width + in_w], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradReplicateNDHWC( + const int out_size, T* d_in_data, const int num, const int channels, + const int in_depth, const int in_height, const int in_width, + const int out_depth, const int out_height, const int out_width, + const int pad_front, const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + const int c = out_index % channels; + int n = out_index / channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + + const int in_d = min(in_depth - 1, max(out_d - pad_front, 0)); + const int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + const int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + + platform::CudaAtomicAdd( + &d_in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradCircularNCDHW(const int out_size, T* d_in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, + const int out_width, const int pad_front, + const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + int nc = out_index / out_width; + const int out_w = out_index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int out_d = nc % out_depth; + nc /= out_depth; + + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + platform::CudaAtomicAdd( + &d_in_data[nc * in_depth * in_height * in_width + + in_d * in_height * in_width + in_h * in_width + in_w], + d_out_data[out_index]); + } +} + +template +__global__ void Pad3DGradCircularNDHWC(const int out_size, T* d_in_data, + const int num, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, + const int out_width, const int pad_front, + const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_KERNEL_LOOP(out_index, out_size) { + const int c = out_index % channels; + int n = out_index / channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int out_d = n % out_depth; + n /= out_depth; + + int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth; + int in_h = ((out_h - pad_top) % in_height + in_height) % in_height; + int in_w = ((out_w - pad_left) % in_width + in_width) % in_width; + + platform::CudaAtomicAdd( + &d_in_data[n * in_depth * in_height * in_width * channels + + in_d * in_height * in_width * channels + + in_h * in_width * channels + in_w * channels + c], + d_out_data[out_index]); + } +} + +static inline void GetPaddings(int* paddings, + const framework::ExecutionContext& context) { + auto* paddings_data = context.Input("Paddings"); + if (paddings_data) { + Tensor pads; + framework::TensorCopySync(*paddings_data, platform::CPUPlace(), &pads); 
+    auto pads_data = pads.data<int>();
+    paddings[0] = pads_data[0];
+    paddings[1] = pads_data[1];
+    paddings[2] = pads_data[2];
+    paddings[3] = pads_data[3];
+    paddings[4] = pads_data[4];
+    paddings[5] = pads_data[5];
+  } else {
+    auto pads = context.Attr<std::vector<int>>("paddings");
+    std::copy(pads.begin(), pads.end(), paddings);
+  }
+}
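The forward kernel class below first derives the padded output shape; for NCDHW the arithmetic is simply D_out = D + pad_front + pad_back, H_out = H + pad_top + pad_bottom, W_out = W + pad_left + pad_right. The same computation in Python (an illustrative sketch only):

def pad3d_out_shape(in_shape, pads, data_format="NCDHW"):
    # pads = [left, right, top, bottom, front, back]
    if data_format == "NCDHW":
        n, c, d, h, w = in_shape
    else:  # NDHWC
        n, d, h, w, c = in_shape
    d += pads[4] + pads[5]
    h += pads[2] + pads[3]
    w += pads[0] + pads[1]
    return (n, c, d, h, w) if data_format == "NCDHW" else (n, d, h, w, c)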
+
+template <typename T>
+class Pad3dCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    int pads[6];
+    GetPaddings(pads, context);
+    auto mode = context.Attr<std::string>("mode");
+    auto data_format = context.Attr<std::string>("data_format");
+    T value = static_cast<T>(context.Attr<float>("value"));
+
+    auto* x = context.Input<Tensor>("X");
+    auto in_dims = x->dims();
+    const T* in_data = x->data<T>();
+    auto* out = context.Output<Tensor>("Out");
+    auto out_dims = out->dims();
+    if (data_format == "NCDHW") {
+      out_dims[0] = in_dims[0];
+      out_dims[1] = in_dims[1];
+      out_dims[2] = in_dims[2] + pads[4] + pads[5];
+      out_dims[3] = in_dims[3] + pads[2] + pads[3];
+      out_dims[4] = in_dims[4] + pads[0] + pads[1];
+    } else {
+      out_dims[0] = in_dims[0];
+      out_dims[1] = in_dims[1] + pads[4] + pads[5];
+      out_dims[2] = in_dims[2] + pads[2] + pads[3];
+      out_dims[3] = in_dims[3] + pads[0] + pads[1];
+      out_dims[4] = in_dims[4];
+    }
+    T* out_data = out->mutable_data<T>(out_dims, context.GetPlace());
+
+    int channels = in_dims[1];
+    int in_depth = in_dims[2];
+    int in_height = in_dims[3];
+    int in_width = in_dims[4];
+    int out_depth = out_dims[2];
+    int out_height = out_dims[3];
+    int out_width = out_dims[4];
+    if (data_format == "NDHWC") {
+      channels = in_dims[4];
+      in_depth = in_dims[1];
+      in_height = in_dims[2];
+      in_width = in_dims[3];
+      out_depth = out_dims[1];
+      out_height = out_dims[2];
+      out_width = out_dims[3];
+    }
+
+    if (mode == "reflect") {
+      PADDLE_ENFORCE_GT(in_depth, pads[4],
+                        platform::errors::InvalidArgument(
+                            "The depth of Input(X)'s dimension should be "
+                            "greater than pad_front"
+                            " in reflect mode"
+                            ", but received depth(%d) and pad_front(%d).",
+                            in_depth, pads[4]));
+      PADDLE_ENFORCE_GT(in_depth, pads[5],
+                        platform::errors::InvalidArgument(
+                            "The depth of Input(X)'s dimension should be "
+                            "greater than pad_back"
+                            " in reflect mode"
+                            ", but received depth(%d) and pad_back(%d).",
+                            in_depth, pads[5]));
+
+      PADDLE_ENFORCE_GT(in_height, pads[2],
+                        platform::errors::InvalidArgument(
+                            "The height of Input(X)'s dimension should be "
+                            "greater than pad_top"
+                            " in reflect mode"
+                            ", but received height(%d) and pad_top(%d).",
+                            in_height, pads[2]));
+      PADDLE_ENFORCE_GT(in_height, pads[3],
+                        platform::errors::InvalidArgument(
+                            "The height of Input(X)'s dimension should be "
+                            "greater than pad_bottom"
+                            " in reflect mode"
+                            ", but received height(%d) and pad_bottom(%d).",
+                            in_height, pads[3]));
+
+      PADDLE_ENFORCE_GT(in_width, pads[0],
+                        platform::errors::InvalidArgument(
+                            "The width of Input(X)'s dimension should be "
+                            "greater than pad_left"
+                            " in reflect mode"
+                            ", but received width(%d) and pad_left(%d).",
+                            in_width, pads[0]));
+      PADDLE_ENFORCE_GT(in_width, pads[1],
+                        platform::errors::InvalidArgument(
+                            "The width of Input(X)'s dimension should be "
+                            "greater than pad_right"
+                            " in reflect mode"
+                            ", but received width(%d) and pad_right(%d).",
+                            in_width, pads[1]));
+    }
+
+    const int pad_left = pads[0];
+    const int pad_top = pads[2];
+    const int pad_front = pads[4];
+    const int num = in_dims[0];
+
+    auto stream = context.cuda_device_context().stream();
+    int block = PADDLE_CUDA_NUM_THREADS;
+    const int out_size = out->numel();
+    int grid = (out_size + block - 1) / block;
+
+    if (data_format == "NCDHW") {
+      if (mode == "reflect") {
+        Pad3DReflectNCDHW<T><<<grid, block, 0, stream>>>(
+            out_size, in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            out_data);
+      } else if (mode == "replicate") {
+        Pad3DReplicateNCDHW<T><<<grid, block, 0, stream>>>(
+            out_size, in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            out_data);
+      } else if (mode == "circular") {
+        Pad3DCircularNCDHW<T><<<grid, block, 0, stream>>>(
+            out_size, in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            out_data);
+      } else {
+        Pad3DConstNCDHW<T><<<grid, block, 0, stream>>>(
+            out_size, in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            value, out_data);
+      }
+    } else {
+      if (mode == "reflect") {
+        Pad3DReflectNDHWC<T><<<grid, block, 0, stream>>>(
+            out_size, in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            out_data);
+      } else if (mode == "replicate") {
+        Pad3DReplicateNDHWC<T><<<grid, block, 0, stream>>>(
+            out_size, in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            out_data);
+      } else if (mode == "circular") {
+        Pad3DCircularNDHWC<T><<<grid, block, 0, stream>>>(
+            out_size, in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            out_data);
+      } else {
+        Pad3DConstNDHWC<T><<<grid, block, 0, stream>>>(
+            out_size, in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            value, out_data);
+      }
+    }
+  }
+};
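Each thread in the kernels above owns one flat output offset and peels it apart with div/mod before mapping it back to an input offset. The same decomposition in Python, mirroring the NCDHW variants (a sketch, not part of the patch); the gradient kernels below reuse it, but scatter with atomic adds because the reflect/replicate/circular mappings are many-to-one:

def decode_ncdhw(index, out_depth, out_height, out_width):
    out_w = index % out_width
    nc = index // out_width
    out_h = nc % out_height
    nc //= out_height
    out_d = nc % out_depth
    nc //= out_depth  # nc now indexes the fused (batch, channel) dimension
    return nc, out_d, out_h, out_w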
+
+template <typename T>
+class Pad3dGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    int pads[6];
+    GetPaddings(pads, context);
+    auto mode = context.Attr<std::string>("mode");
+    auto data_format = context.Attr<std::string>("data_format");
+    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* d_in = context.Output<Tensor>(framework::GradVarName("X"));
+    auto d_in_dims = d_in->dims();
+    auto d_out_dims = d_out->dims();
+    const T* d_out_data = d_out->data<T>();
+    T* d_in_data = d_in->mutable_data<T>(context.GetPlace());
+
+    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
+    set_zero(context.template device_context<platform::CUDADeviceContext>(),
+             d_in, static_cast<T>(0));
+
+    const int pad_left = pads[0];
+    const int pad_top = pads[2];
+    const int pad_front = pads[4];
+
+    const int num = d_in_dims[0];
+
+    auto stream = context.cuda_device_context().stream();
+    int block = PADDLE_CUDA_NUM_THREADS;
+    const int out_size = d_out->numel();
+    const int in_size = d_in->numel();
+    int grid = (out_size + block - 1) / block;
+
+    if (data_format == "NCDHW") {
+      const int channels = d_in_dims[1];
+      const int in_depth = d_in_dims[2];
+      const int in_height = d_in_dims[3];
+      const int in_width = d_in_dims[4];
+      const int out_depth = d_out_dims[2];
+      const int out_height = d_out_dims[3];
+      const int out_width = d_out_dims[4];
+
+      if (mode == "reflect") {
+        Pad3DGradReflectNCDHW<T><<<grid, block, 0, stream>>>(
+            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            d_out_data);
+      } else if (mode == "replicate") {
+        Pad3DGradReplicateNCDHW<T><<<grid, block, 0, stream>>>(
+            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            d_out_data);
+      } else if (mode == "circular") {
+        Pad3DGradCircularNCDHW<T><<<grid, block, 0, stream>>>(
+            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            d_out_data);
+      } else {
+        grid = (in_size + block - 1) / block;
+        Pad3DGradConstNCDHW<T><<<grid, block, 0, stream>>>(
+            in_size, d_in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            d_out_data);
+      }
+    } else {
+      const int channels = d_in_dims[4];
+      const int in_depth = d_in_dims[1];
+      const int in_height = d_in_dims[2];
+      const int in_width = d_in_dims[3];
+      const int out_depth = d_out_dims[1];
+      const int out_height = d_out_dims[2];
+      const int out_width = d_out_dims[3];
+      if (mode == "reflect") {
+        Pad3DGradReflectNDHWC<T><<<grid, block, 0, stream>>>(
+            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            d_out_data);
+      } else if (mode == "replicate") {
+        Pad3DGradReplicateNDHWC<T><<<grid, block, 0, stream>>>(
+            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            d_out_data);
+      } else if (mode == "circular") {
+        Pad3DGradCircularNDHWC<T><<<grid, block, 0, stream>>>(
+            out_size, d_in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            d_out_data);
+      } else {
+        grid = (in_size + block - 1) / block;
+        Pad3DGradConstNDHWC<T><<<grid, block, 0, stream>>>(
+            in_size, d_in_data, num, channels, in_depth, in_height, in_width,
+            out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+            d_out_data);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(pad3d, ops::Pad3dCUDAKernel<plat::float16>,
+                        ops::Pad3dCUDAKernel<float>,
+                        ops::Pad3dCUDAKernel<double>, ops::Pad3dCUDAKernel<int>,
+                        ops::Pad3dCUDAKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL(pad3d_grad, ops::Pad3dGradCUDAKernel<plat::float16>,
+                        ops::Pad3dGradCUDAKernel<float>,
+                        ops::Pad3dGradCUDAKernel<double>);
diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
new file mode 100644
index 0000000000000..4cf27184b4194
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
@@ -0,0 +1,387 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
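One convention worth flagging before the tests: pad3d's flat pad list is ordered fastest-varying axis first ([left, right, top, bottom, front, back]), while np.pad expects per-axis pairs in array order (slowest first), so the test cases below reverse the grouping when they build reference outputs. An equivalent helper (illustrative only; the tests inline this logic):

def to_np_pad(pad, data_format="NCDHW"):
    dhw = [(pad[4], pad[5]), (pad[2], pad[3]), (pad[0], pad[1])]  # D, H, W
    if data_format == "NCDHW":
        return [(0, 0), (0, 0)] + dhw
    return [(0, 0)] + dhw + [(0, 0)]  # NDHWC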
+ +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.fluid.dygraph as dg +import paddle.fluid.core as core + +from paddle.fluid import Program, program_guard + + +class TestPad3dOp(OpTest): + def setUp(self): + self.value = 0.0 + self.variable_paddings = False + self.initTestCase() + self.op_type = "pad3d" + self.inputs = {'X': np.random.random(self.shape).astype("float64")} + self.attrs = {} + if self.variable_paddings: + self.attrs['paddings'] = [] + self.inputs['Paddings'] = np.array(self.paddings).flatten().astype( + "int32") + else: + self.attrs['paddings'] = np.array(self.paddings).flatten().astype( + "int32") + self.attrs['value'] = self.value + self.attrs['mode'] = self.mode + self.attrs['data_format'] = self.data_format + if self.data_format == "NCDHW": + paddings = [ + (0, 0), + (0, 0), + (self.paddings[4], self.paddings[5]), + (self.paddings[2], self.paddings[3]), + (self.paddings[0], self.paddings[1]), + ] + else: + paddings = [ + (0, 0), + (self.paddings[4], self.paddings[5]), + (self.paddings[2], self.paddings[3]), + (self.paddings[0], self.paddings[1]), + (0, 0), + ] + if self.mode == "constant": + out = np.pad(self.inputs['X'], + paddings, + mode=self.mode, + constant_values=self.value) + elif self.mode == "reflect": + out = np.pad(self.inputs['X'], paddings, mode=self.mode) + elif self.mode == "replicate": + out = np.pad(self.inputs['X'], paddings, mode="edge") + elif self.mode == "circular": + out = np.pad(self.inputs['X'], paddings, mode="wrap") + self.outputs = {'Out': out} + print("[gry debug]outputs shape: ", out.shape) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out') + + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.paddings = [0, 0, 0, 0, 0, 0] + self.mode = "constant" + self.data_format = "NCDHW" + self.pad_value = 0.0 + + +class TestCase1(TestPad3dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.paddings = [0, 1, 2, 3, 4, 5] + self.mode = "constant" + self.data_format = "NCDHW" + self.value = 1.0 + + +class TestCase2(TestPad3dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.paddings = [1, 1, 1, 1, 1, 1] + self.mode = "constant" + self.data_format = "NDHWC" + self.value = 1.0 + + +class TestCase3(TestPad3dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.paddings = [0, 1, 1, 0, 2, 3] + self.mode = "reflect" + self.data_format = "NCDHW" + + +class TestCase4(TestPad3dOp): + def initTestCase(self): + self.shape = (4, 4, 4, 4, 4) + self.paddings = [0, 1, 2, 1, 2, 3] + self.mode = "reflect" + self.data_format = "NDHWC" + + +class TestCase5(TestPad3dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.paddings = [0, 1, 2, 3, 2, 1] + self.mode = "replicate" + self.data_format = "NCDHW" + + +class TestCase6(TestPad3dOp): + def initTestCase(self): + self.shape = (4, 4, 4, 4, 4) + self.paddings = [5, 4, 2, 1, 2, 3] + self.mode = "replicate" + self.data_format = "NDHWC" + + +class TestCase7(TestPad3dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.paddings = [0, 1, 2, 3, 2, 1] + self.mode = "circular" + self.data_format = "NCDHW" + + +class TestCase8(TestPad3dOp): + def initTestCase(self): + self.shape = (4, 4, 4, 4, 4) + self.paddings = [0, 1, 2, 1, 2, 3] + self.mode = "circular" + self.data_format = "NDHWC" + + +class TestCase9(TestPad3dOp): + def initTestCase(self): + self.shape = (2, 3, 
4, 5, 6) + self.paddings = [0, 1, 2, 3, 3, 1] + self.mode = "reflect" + self.data_format = "NCDHW" + self.variable_paddings = True + + +class TestPad3dDygraph(unittest.TestCase): + def _get_numpy_out(self, input_data, pad, mode, value, data_format="NCDHW"): + if data_format == "NCDHW": + pad = [ + (0, 0), + (0, 0), + (pad[4], pad[5]), + (pad[2], pad[3]), + (pad[0], pad[1]), + ] + else: + pad = [ + (0, 0), + (pad[4], pad[5]), + (pad[2], pad[3]), + (pad[0], pad[1]), + (0, 0), + ] + + if mode == "constant": + out = np.pad(input_data, pad, mode=mode, constant_values=value) + elif mode == "reflect": + out = np.pad(input_data, pad, mode=mode) + elif mode == "replicate": + out = np.pad(input_data, pad, mode="edge") + elif mode == "circular": + out = np.pad(input_data, pad, mode="wrap") + + return out + + def test_dygraph(self): + + input_shape = (1, 2, 3, 4, 5) + pad = [1, 2, 1, 1, 3, 4] + mode = "constant" + value = 100 + input_data = np.random.rand(*input_shape).astype(np.float32) + np_out = self._get_numpy_out(input_data, pad, mode, value) + place = paddle.CPUPlace() + with dg.guard(place) as g: + input = dg.to_variable(input_data) + output = F.pad(input=input, pad=pad, mode=mode, value=value) + self.assertTrue(np.allclose(output.numpy(), np_out)) + + +class TestPadAPI(unittest.TestCase): + def _get_numpy_out(self, input_data, pad, mode, value, data_format="NCDHW"): + if data_format == "NCDHW": + pad = [ + (0, 0), + (0, 0), + (pad[4], pad[5]), + (pad[2], pad[3]), + (pad[0], pad[1]), + ] + else: + pad = [ + (0, 0), + (pad[4], pad[5]), + (pad[2], pad[3]), + (pad[0], pad[1]), + (0, 0), + ] + + if mode == "constant": + out = np.pad(input_data, pad, mode=mode, constant_values=value) + elif mode == "reflect": + out = np.pad(input_data, pad, mode=mode) + elif mode == "replicate": + out = np.pad(input_data, pad, mode="edge") + elif mode == "circular": + out = np.pad(input_data, pad, mode="wrap") + + return out + + def setUp(self): + self.places = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + with paddle.program_guard(paddle.Program(), paddle.Program()): + input_shape = (1, 2, 3, 4, 5) + pad = [1, 2, 1, 1, 3, 4] + mode = "constant" + value = 100 + input_data = np.random.rand(*input_shape).astype(np.float32) + x = paddle.data(name="x", shape=input_shape) + result = F.pad(input=x, pad=pad, value=value, mode='constant') + exe = paddle.Executor(place) + fetches = exe.run(paddle.default_main_program(), + feed={"x": input_data}, + fetch_list=[result]) + + np_out = self._get_numpy_out(input_data, pad, mode, value) + + self.assertTrue(np.allclose(fetches[0], np_out)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + input_shape = (1, 2, 3, 4, 5) + pad = [1, 2, 1, 1, 3, 4] + mode = "constant" + value = 100 + input_data = np.random.rand(*input_shape).astype(np.float32) + np_out = self._get_numpy_out(input_data, pad, mode, value) + with dg.guard(place) as g: + input = dg.to_variable(input_data) + output = F.pad(input=input, pad=pad, mode=mode, value=value) + self.assertTrue(np.allclose(output.numpy(), np_out)) + + +class TestPad1dClass(unittest.TestCase): + def _get_numpy_out(self, + input_data, + pad, + mode, + value=0.0, + data_format="NCL"): + if data_format == "NCL": + pad = [ + (0, 0), + (0, 0), + (pad[0], pad[1]), + ] + else: + pad = [ + (0, 0), + (pad[0], pad[1]), + (0, 0), + ] + + if mode == "constant": + out = 
np.pad(input_data, pad, mode=mode, constant_values=value) + elif mode == "reflect": + out = np.pad(input_data, pad, mode=mode) + elif mode == "replicate": + out = np.pad(input_data, pad, mode="edge") + + return out + + def setUp(self): + self.places = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + + def test_class(self): + for place in self.places: + input_shape = (3, 4, 5) + pad = [1, 2] + value = 100 + input_data = np.random.rand(*input_shape).astype(np.float32) + pad_reflection = nn.ReflectionPad1d(pad=pad) + pad_replication = nn.ReplicationPad1d(pad=pad) + + pad_constant = nn.ConstantPad1d(pad=pad, value=value) + with dg.guard(place) as g: + data = paddle.fluid.dygraph.to_variable(input_data) + + output = pad_reflection(data) + np_out = self._get_numpy_out( + input_data, pad, "reflect", data_format="NCL") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + output = pad_replication(data) + np_out = self._get_numpy_out( + input_data, pad, "replicate", data_format="NCL") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + output = pad_constant(data) + np_out = self._get_numpy_out( + input_data, pad, "constant", value=value, data_format="NCL") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + +class TestPad3dOpError(unittest.TestCase): + def test_errors(self): + def test_variable(): + input_shape = (1, 2, 3, 4, 5) + data = np.random.rand(*input_shape).astype(np.float32) + F.pad(input=data, paddings=[1, 1, 1, 1, 1, 1]) + + def test_reflect_1(): + input_shape = (1, 2, 3, 4, 5) + data = np.random.rand(*input_shape).astype(np.float32) + x = paddle.data(name="x", shape=input_shape) + y = F.pad(x, pad=[5, 6, 1, 1, 1, 1], value=1, mode='reflect') + place = paddle.CPUPlace() + exe = paddle.Executor(place) + outputs = exe.run(feed={'x': data}, fetch_list=[y.name]) + + def test_reflect_2(): + input_shape = (1, 2, 3, 4, 5) + data = np.random.rand(*input_shape).astype(np.float32) + x = paddle.data(name="x", shape=input_shape) + y = F.pad(x, pad=[1, 1, 4, 3, 1, 1], value=1, mode='reflect') + place = paddle.CPUPlace() + exe = paddle.Executor(place) + outputs = exe.run(feed={'x': data}, fetch_list=[y.name]) + + def test_reflect_3(): + input_shape = (1, 2, 3, 4, 5) + data = np.random.rand(*input_shape).astype(np.float32) + x = paddle.data(name="x", shape=input_shape) + y = F.pad(x, pad=[1, 1, 1, 1, 2, 3], value=1, mode='reflect') + place = paddle.CPUPlace() + exe = paddle.Executor(place) + outputs = exe.run(feed={'x': data}, fetch_list=[y.name]) + + self.assertRaises(TypeError, test_variable) + + self.assertRaises(Exception, test_reflect_1) + + self.assertRaises(Exception, test_reflect_2) + + self.assertRaises(Exception, test_reflect_3) + + +if __name__ == '__main__': + unittest.main() From ac412e7f9a28a9794d1f4f019703c43513c6ba78 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Tue, 11 Aug 2020 09:08:01 +0000 Subject: [PATCH 04/15] restore pad2d --- python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/common.py | 4 +- python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/common.py | 82 ++++++++++++++++++++++++++- 4 files changed, 85 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index c41270c96503c..4223b0507fa76 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -60,6 +60,7 @@ from .layer.activation import HSigmoid #DEFINE_ALIAS from .layer.common import BilinearTensorProduct #DEFINE_ALIAS from .layer.common import 
Pool2D #DEFINE_ALIAS +from .layer.common import Pad2D #DEFINE_ALIAS from .layer.common import Pad #DEFINE_ALIAS from .layer.common import ReflectionPad1d #DEFINE_ALIAS from .layer.common import ReplicationPad1d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 32a191b47d8e7..2bd009c4117e4 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -491,8 +491,8 @@ def pad(input, Examples: .. code-block:: text - Input = [[[[1., 2., 3.], - [4., 5., 6.]]]] + Input = [[[[[1., 2., 3.], + [4., 5., 6.]]]]] Case 0: pad = [2, 2, 1, 1, 0, 0], diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index f09d655b07100..b6933a1545cc9 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -36,6 +36,7 @@ from .activation import HSigmoid #DEFINE_ALIAS from .common import BilinearTensorProduct #DEFINE_ALIAS from .common import Pool2D #DEFINE_ALIAS +from .common import Pad2D #DEFINE_ALIAS from .common import Pad #DEFINE_ALIAS from .common import ReflectionPad1d #DEFINE_ALIAS from .common import ReplicationPad1d #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index e98eafebbc498..d26b0d7bbf598 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -23,7 +23,7 @@ __all__ = [ 'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample', 'Pad', - 'ReflectionPad1d', 'ReplicationPad1d', 'ConstantPad1d' + 'Pad2D', 'ReflectionPad1d', 'ReplicationPad1d', 'ConstantPad1d' ] @@ -254,6 +254,86 @@ def forward(self, input): return out +class Pad2D(layers.Layer): + """ + :alias_main: paddle.nn.Pad2D + :alias: paddle.nn.Pad2D,paddle.nn.layer.Pad2D,paddle.nn.layer.common.Pad2D + This interface is used to construct a callable object of the ``Pad2D`` class. + The Pad2D layer pads the input tensor boundaries according to 'paddings' and 'mode'. + If mode is 'reflect', paddings[0] and paddings[1] must be no greater + than height-1. And the width dimension has the same condition. + Parameters: + paddings (int | List[int32]): The padding size. If padding is a int, uses the same + padding in all boundaries, if padding is a List, it must contain four integers, + (padding_top, padding_bottom, padding_left, padding_right). + Default is [0, 0, 0, 0]. + mode (str): Three modes: 'constant' (default), 'reflect', 'edge' . + When in 'constant' mode, this op uses a constant value to pad the input tensor. + When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor. + When in 'edge' mode, uses input boundaries to pad the input tensor. + Default is 'constant' + pad_value (float32): The value to fill the padded areas in 'constant' mode . Default is 0.0 + data_format (str): An string from: "NHWC", "NCHW". Specify the data format of + the input data. + Default is "NCHW" + Returns: + None + Examples: + .. code-block:: text + Input = [[[[1., 2., 3.], + [4., 5., 6.]]]] + Case 0: + paddings = [0, 1, 2, 3], + mode = 'constant' + pad_value = 0 + Out = [[[[0., 0., 1., 2., 3., 0., 0., 0.], + [0., 0., 4., 5., 6., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.]]]] + Case 1: + paddings = [0, 1, 2, 1], + mode = 'reflect' + Out = [[[[3., 2., 1., 2., 3., 2.], + [6., 5., 4., 5., 6., 5.], + [3., 2., 1., 2., 3., 2.]]]] + Case 2: + paddings = [0, 1, 2, 1], + mode = 'edge' + Out = [[[[1., 1., 1., 2., 3., 3.], + [4., 4., 4., 5., 6., 6.], + [4., 4., 4., 5., 6., 6.]]]] + Code Examples: + .. 
code-block:: python + import paddle.fluid as fluid + import paddle.nn as nn + import numpy as np + data = np.ones((2, 2, 2, 2)).astype('float32') + my_pad = nn.Pad2D(paddings=[1, 1, 1, 1]) + with fluid.dygraph.guard(): + data = fluid.dygraph.to_variable(data) + result = my_pad(data) + """ + + def __init__(self, + paddings=0, + mode='constant', + pad_value=0.0, + data_format="NCHW"): + super(Pad2D, self).__init__() + self._mode = mode + self._pad_value = pad_value + self._data_format = data_format + self._paddings = [paddings] * 4 if isinstance(paddings, + int) else paddings + + def forward(self, input): + return F.pad2d( + input, + paddings=self._paddings, + mode=self._mode, + pad_value=self._pad_value, + data_format=self._data_format) + + class Pad(layers.Layer): """ :alias_main: paddle.nn.Pad From 2d423cdd831d74e0dd577eedbcd06ba57af25581 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Tue, 11 Aug 2020 11:37:53 +0000 Subject: [PATCH 05/15] test=develop, fix paddl declare --- .../paddle/fluid/tests/unittests/test_pad3d_op.py | 15 +++++++-------- python/paddle/nn/functional/common.py | 4 ++++ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py index 4cf27184b4194..09b5d4ea58331 100644 --- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py @@ -21,7 +21,7 @@ import paddle.fluid.dygraph as dg import paddle.fluid.core as core -from paddle.fluid import Program, program_guard +from paddle.fluid import Program, program_guard, Executor, default_main_program class TestPad3dOp(OpTest): @@ -70,7 +70,6 @@ def setUp(self): elif self.mode == "circular": out = np.pad(self.inputs['X'], paddings, mode="wrap") self.outputs = {'Out': out} - print("[gry debug]outputs shape: ", out.shape) def test_check_output(self): self.check_output() @@ -242,7 +241,7 @@ def setUp(self): self.places.append(paddle.CUDAPlace(0)) def check_static_result(self, place): - with paddle.program_guard(paddle.Program(), paddle.Program()): + with program_guard(Program(), Program()): input_shape = (1, 2, 3, 4, 5) pad = [1, 2, 1, 1, 3, 4] mode = "constant" @@ -250,8 +249,8 @@ def check_static_result(self, place): input_data = np.random.rand(*input_shape).astype(np.float32) x = paddle.data(name="x", shape=input_shape) result = F.pad(input=x, pad=pad, value=value, mode='constant') - exe = paddle.Executor(place) - fetches = exe.run(paddle.default_main_program(), + exe = Executor(place) + fetches = exe.run(default_main_program(), feed={"x": input_data}, fetch_list=[result]) @@ -353,7 +352,7 @@ def test_reflect_1(): x = paddle.data(name="x", shape=input_shape) y = F.pad(x, pad=[5, 6, 1, 1, 1, 1], value=1, mode='reflect') place = paddle.CPUPlace() - exe = paddle.Executor(place) + exe = Executor(place) outputs = exe.run(feed={'x': data}, fetch_list=[y.name]) def test_reflect_2(): @@ -362,7 +361,7 @@ def test_reflect_2(): x = paddle.data(name="x", shape=input_shape) y = F.pad(x, pad=[1, 1, 4, 3, 1, 1], value=1, mode='reflect') place = paddle.CPUPlace() - exe = paddle.Executor(place) + exe = Executor(place) outputs = exe.run(feed={'x': data}, fetch_list=[y.name]) def test_reflect_3(): @@ -371,7 +370,7 @@ def test_reflect_3(): x = paddle.data(name="x", shape=input_shape) y = F.pad(x, pad=[1, 1, 1, 1, 2, 3], value=1, mode='reflect') place = paddle.CPUPlace() - exe = paddle.Executor(place) + exe = Executor(place) outputs = exe.run(feed={'x': data}, 
fetch_list=[y.name])
 
         self.assertRaises(TypeError, test_variable)
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 2bd009c4117e4..5bfd27e169cd6 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -587,6 +587,10 @@ def pad(input,
             pad = concat([pad, zeros((2, ), dtype="int32")], axis=0)
             unsqueezed_dim = [1]
             input = unsqueeze(input, axes=unsqueezed_dim)
+        else:
+            raise ValueError(
+                "data_format should be in one of "
+                "[NCL, NCHW, NCDHW, NLC, NHWC, NDHWC] but got {}".format(data_format))
     else:
         if data_format in ["NCL", "NCHW", "NCDHW"]:
             data_format = "NCDHW"

From 55c1189776994cdad5ec5206653329da67409c92 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Tue, 11 Aug 2020 12:21:37 +0000
Subject: [PATCH 06/15] fix pad interface

---
 .../fluid/tests/unittests/test_pad3d_op.py    |  7 +++---
 python/paddle/nn/layer/common.py              | 22 +++++++++----------
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
index 09b5d4ea58331..080b043d0070f 100644
--- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
@@ -316,10 +316,11 @@ def test_class(self):
         for place in self.places:
             input_shape = (3, 4, 5)
             pad = [1, 2]
             value = 100
             input_data = np.random.rand(*input_shape).astype(np.float32)
-            pad_reflection = nn.ReflectionPad1d(pad=pad)
-            pad_replication = nn.ReplicationPad1d(pad=pad)
-            pad_constant = nn.ConstantPad1d(pad=pad, value=value)
+
+            pad_reflection = nn.ReflectionPad1d(padding=pad)
+            pad_replication = nn.ReplicationPad1d(padding=pad)
+            pad_constant = nn.ConstantPad1d(padding=pad, value=value)
 
             with dg.guard(place) as g:
                 data = paddle.fluid.dygraph.to_variable(input_data)
diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index d26b0d7bbf598..7c40ea45166a6 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -340,8 +340,8 @@ class Pad(layers.Layer):
     :alias: paddle.nn.Pad
 
     This interface is used to construct a callable object of the ``Pad`` class.
-    The Pad layer pads the input tensor boundaries according to 'paddings' and 'mode'.
-    If mode is 'reflect', paddings[0] and paddings[1] must be no greater
+    The Pad layer pads the input tensor boundaries according to 'pad' and 'mode'.
+    If mode is 'reflect', pad[0] and pad[1] must be no greater
     than width-1. The height and depth dimensions have the same condition.
 
     Parameters:
@@ -482,7 +482,7 @@ class ReflectionPad1d(layers.Layer):
             input_shape = (1, 2, 3)
             pad = [1, 2]
             data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
-            my_pad = nn.ReflectionPad1d(pad=pad)
+            my_pad = nn.ReflectionPad1d(padding=pad)
             with fluid.dygraph.guard():
                 data = fluid.dygraph.to_variable(data)
                 result = my_pad(data)
             # [[[2. 1. 2. 3. 2. 1.]
             #   [5. 4. 5. 6. 5. 
4.]]] """ - def __init__(self, pad=[0, 0], data_format="NCL", name=None): + def __init__(self, padding=[0, 0], data_format="NCL", name=None): super(ReflectionPad1d, self).__init__() self._mode = "reflect" self._data_format = data_format - self._pad = pad + self._pad = padding self._name = name def forward(self, input): @@ -543,7 +543,7 @@ class ReplicationPad1d(layers.Layer): input_shape = (1, 2, 3) pad = [1, 2] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 - my_pad = nn.ReplicationPad1d(pad=pad) + my_pad = nn.ReplicationPad1d(padding=pad) with fluid.dygraph.guard(): data = fluid.dygraph.to_variable(data) result = my_pad(data) @@ -552,11 +552,11 @@ class ReplicationPad1d(layers.Layer): # [1. 4. 5. 6. 6. 6.]]] """ - def __init__(self, pad=[0, 0], data_format="NCL", name=None): + def __init__(self, padding=[0, 0], data_format="NCL", name=None): super(ReplicationPad1d, self).__init__() self._mode = "replicate" self._data_format = data_format - self._pad = pad + self._pad = padding self._name = name def forward(self, input): @@ -606,7 +606,7 @@ class ConstantPad1d(layers.Layer): pad = [1, 2] value = 0.0 data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 - my_pad = nn.ConstantPad1d(pad=pad, value=value) + my_pad = nn.ConstantPad1d(padding=pad, value=value) with fluid.dygraph.guard(): data = fluid.dygraph.to_variable(data) result = my_pad(data) @@ -615,11 +615,11 @@ class ConstantPad1d(layers.Layer): # [0. 4. 5. 6. 0. 0.]]] """ - def __init__(self, pad=[0, 0], value=0.0, data_format="NCL", name=None): + def __init__(self, padding=[0, 0], value=0.0, data_format="NCL", name=None): super(ConstantPad1d, self).__init__() self._mode = "constant" self._data_format = data_format - self._pad = pad + self._pad = padding self._value = value self._name = name From 98d11de654f756a316f89d915ef277caf34ae906 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Tue, 11 Aug 2020 12:24:59 +0000 Subject: [PATCH 07/15] test=develop, fix pad --- python/paddle/nn/functional/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 5bfd27e169cd6..4b92a6a37c3f4 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -633,10 +633,8 @@ def pad(input, dtype = helper.input_dtype(input_param_name='input') out = helper.create_variable_for_type_inference(dtype) - helper.append_op( type='pad3d', inputs=inputs, outputs={"Out": out}, attrs=attrs) - if len(unsqueezed_dim) != 0: out = squeeze(out, axes=unsqueezed_dim) From 645fe9090fcd314afaf997f259d6b0992e3b75ca Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Sun, 16 Aug 2020 18:24:34 +0000 Subject: [PATCH 08/15] test=develop, add all pad api and cos_sim --- paddle/fluid/operators/pad3d_op.cc | 867 ++++++++---------- paddle/fluid/operators/pad3d_op.cu | 21 +- .../fluid/tests/unittests/test_pad3d_op.py | 8 +- python/paddle/nn/__init__.py | 7 + python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/common.py | 187 ++-- python/paddle/nn/layer/__init__.py | 7 + python/paddle/nn/layer/common.py | 574 ++++++++++-- python/paddle/tensor/__init__.py | 1 + 9 files changed, 1052 insertions(+), 621 deletions(-) diff --git a/paddle/fluid/operators/pad3d_op.cc b/paddle/fluid/operators/pad3d_op.cc index 73d925fd6ed9a..1d41b823b6551 100644 --- a/paddle/fluid/operators/pad3d_op.cc +++ b/paddle/fluid/operators/pad3d_op.cc @@ -25,228 +25,192 @@ namespace operators { using 
framework::Tensor; template -void Pad3DConstNCDHW(const T* in_data, const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, const int pad_left, - T value, T* out_data) { - for (int n = 0; n < num; ++n) { - for (int c = 0; c < channels; ++c) { - for (int out_d = 0; out_d < out_depth; ++out_d) { - for (int out_h = 0; out_h < out_height; ++out_h) { - for (int out_w = 0; out_w < out_width; ++out_w) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - out_data[out_d * out_height * out_width + out_h * out_width + - out_w] = - (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || - in_h >= in_height || in_w >= in_width) - ? value - : in_data[in_d * in_height * in_width + in_h * in_width + - in_w]; - } - } - } - in_data += in_depth * in_height * in_width; - out_data += out_depth * out_height * out_width; - } - } +void ConstPad3DFuncNCDHW(const T* in_data, T* out_data, const int in_depth, + const int in_height, const int in_width, + const int out_depth, const int out_height, + const int out_width, const int pad_front, + const int pad_top, const int pad_left, const int out_d, + const int out_h, const int out_w, const T value) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + out_data[out_d * out_height * out_width + out_h * out_width + out_w] = + (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width) + ? value + : in_data[in_d * in_height * in_width + in_h * in_width + in_w]; } template -void Pad3DConstNDHWC(const T* in_data, const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, const int pad_left, - T value, T* out_data) { - for (int n = 0; n < num; ++n) { - for (int out_d = 0; out_d < out_depth; ++out_d) { - for (int out_h = 0; out_h < out_height; ++out_h) { - for (int out_w = 0; out_w < out_width; ++out_w) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * - channels; - if (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || - in_h >= in_height || in_w >= in_width) { - for (int c = 0; c < channels; ++c) { - out_data[out_index + c] = value; - } - } else { - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * - channels; - for (int c = 0; c < channels; ++c) { - out_data[out_index + c] = in_data[in_index + c]; - } - } - } - } +void ConstPad3DFuncNDHWC(const T* in_data, T* out_data, const int channels, + const int in_depth, const int in_height, + const int in_width, const int out_depth, + const int out_height, const int out_width, + const int pad_front, const int pad_top, + const int pad_left, const int out_d, const int out_h, + const int out_w, const T value) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + const int out_index = + (out_d * out_height * out_width + out_h * out_width + out_w) * channels; + if (in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth || + in_h >= in_height || in_w >= in_width) { + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = value; + } + } else { + const int in_index = + (in_d * in_height * in_width + in_h * 
in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; } - in_data += in_depth * in_height * in_width * channels; - out_data += out_depth * out_height * out_width * channels; } } template -void Pad3DReflectNCDHW(const T* in_data, const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T* out_data) { - for (int n = 0; n < num; ++n) { - for (int c = 0; c < channels; ++c) { - for (int out_d = 0; out_d < out_depth; ++out_d) { - for (int out_h = 0; out_h < out_height; ++out_h) { - for (int out_w = 0; out_w < out_width; ++out_w) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = std::max(in_d, -in_d); // reflect by 0 - in_d = - std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth - in_h = std::max(in_h, -in_h); // reflect by 0 - in_h = std::min(in_h, - 2 * in_height - in_h - 2); // reflect by in_height - in_w = std::max(in_w, -in_w); // reflect by 0 - in_w = - std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width - - out_data[out_d * out_height * out_width + out_h * out_width + - out_w] = - in_data[in_d * in_height * in_width + in_h * in_width + in_w]; - } - } - } - in_data += in_depth * in_height * in_width; - out_data += out_depth * out_height * out_width; - } - } +void ReflectPad3DFuncNCDHW(const T* in_data, T* out_data, const int in_depth, + const int in_height, const int in_width, + const int out_depth, const int out_height, + const int out_width, const int pad_front, + const int pad_top, const int pad_left, + const int out_d, const int out_h, const int out_w, + const T value) { + int in_d = out_d - pad_front; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + + in_d = std::max(in_d, -in_d); // reflect by 0 + in_d = std::min(in_d, 2 * in_depth - in_d - 2); // reflect by in_depth + in_h = std::max(in_h, -in_h); // reflect by 0 + in_h = std::min(in_h, 2 * in_height - in_h - 2); // reflect by in_height + in_w = std::max(in_w, -in_w); // reflect by 0 + in_w = std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + + out_data[out_d * out_height * out_width + out_h * out_width + out_w] = + in_data[in_d * in_height * in_width + in_h * in_width + in_w]; } template -void Pad3DReflectNDHWC(const T* in_data, const int num, const int channels, - const int in_depth, const int in_height, - const int in_width, const int out_depth, - const int out_height, const int out_width, - const int pad_front, const int pad_top, - const int pad_left, T* out_data) { - for (int n = 0; n < num; ++n) { - for (int out_d = 0; out_d < out_depth; ++out_d) { - for (int out_h = 0; out_h < out_height; ++out_h) { - for (int out_w = 0; out_w < out_width; ++out_w) { - int in_d = out_d - pad_front; - int in_h = out_h - pad_top; - int in_w = out_w - pad_left; - - in_d = std::max(in_d, -in_d); - in_d = std::min(in_d, 2 * in_depth - in_d - 2); - in_h = std::max(in_h, -in_h); - in_h = std::min(in_h, 2 * in_height - in_h - 2); - in_w = std::max(in_w, -in_w); - in_w = std::min(in_w, 2 * in_width - in_w - 2); - - const int out_index = - (out_d * out_height * out_width + out_h * out_width + out_w) * - channels; - const int in_index = - (in_d * in_height * in_width + in_h * in_width + in_w) * channels; - for (int c = 0; c < channels; ++c) { - out_data[out_index + c] = in_data[in_index + c]; - } - } - } - } - 
 template <typename T>
-void Pad3DReflectNCDHW(const T* in_data, const int num, const int channels,
-                       const int in_depth, const int in_height,
-                       const int in_width, const int out_depth,
-                       const int out_height, const int out_width,
-                       const int pad_front, const int pad_top,
-                       const int pad_left, T* out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int c = 0; c < channels; ++c) {
-      for (int out_d = 0; out_d < out_depth; ++out_d) {
-        for (int out_h = 0; out_h < out_height; ++out_h) {
-          for (int out_w = 0; out_w < out_width; ++out_w) {
-            int in_d = out_d - pad_front;
-            int in_h = out_h - pad_top;
-            int in_w = out_w - pad_left;
-
-            in_d = std::max(in_d, -in_d);  // reflect by 0
-            in_d =
-                std::min(in_d, 2 * in_depth - in_d - 2);  // reflect by in_depth
-            in_h = std::max(in_h, -in_h);  // reflect by 0
-            in_h = std::min(in_h,
-                            2 * in_height - in_h - 2);  // reflect by in_height
-            in_w = std::max(in_w, -in_w);  // reflect by 0
-            in_w =
-                std::min(in_w, 2 * in_width - in_w - 2);  // reflect by in_width
-
-            out_data[out_d * out_height * out_width + out_h * out_width +
-                     out_w] =
-                in_data[in_d * in_height * in_width + in_h * in_width + in_w];
-          }
-        }
-      }
-      in_data += in_depth * in_height * in_width;
-      out_data += out_depth * out_height * out_width;
-    }
-  }
+void ReflectPad3DFuncNCDHW(const T* in_data, T* out_data, const int in_depth,
+                           const int in_height, const int in_width,
+                           const int out_depth, const int out_height,
+                           const int out_width, const int pad_front,
+                           const int pad_top, const int pad_left,
+                           const int out_d, const int out_h, const int out_w,
+                           const T value) {
+  int in_d = out_d - pad_front;
+  int in_h = out_h - pad_top;
+  int in_w = out_w - pad_left;
+
+  in_d = std::max(in_d, -in_d);                     // reflect by 0
+  in_d = std::min(in_d, 2 * in_depth - in_d - 2);   // reflect by in_depth
+  in_h = std::max(in_h, -in_h);                     // reflect by 0
+  in_h = std::min(in_h, 2 * in_height - in_h - 2);  // reflect by in_height
+  in_w = std::max(in_w, -in_w);                     // reflect by 0
+  in_w = std::min(in_w, 2 * in_width - in_w - 2);   // reflect by in_width
+
+  out_data[out_d * out_height * out_width + out_h * out_width + out_w] =
+      in_data[in_d * in_height * in_width + in_h * in_width + in_w];
 }
 
 template <typename T>
-void Pad3DReflectNDHWC(const T* in_data, const int num, const int channels,
-                       const int in_depth, const int in_height,
-                       const int in_width, const int out_depth,
-                       const int out_height, const int out_width,
-                       const int pad_front, const int pad_top,
-                       const int pad_left, T* out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int out_d = 0; out_d < out_depth; ++out_d) {
-      for (int out_h = 0; out_h < out_height; ++out_h) {
-        for (int out_w = 0; out_w < out_width; ++out_w) {
-          int in_d = out_d - pad_front;
-          int in_h = out_h - pad_top;
-          int in_w = out_w - pad_left;
-
-          in_d = std::max(in_d, -in_d);
-          in_d = std::min(in_d, 2 * in_depth - in_d - 2);
-          in_h = std::max(in_h, -in_h);
-          in_h = std::min(in_h, 2 * in_height - in_h - 2);
-          in_w = std::max(in_w, -in_w);
-          in_w = std::min(in_w, 2 * in_width - in_w - 2);
-
-          const int out_index =
-              (out_d * out_height * out_width + out_h * out_width + out_w) *
-              channels;
-          const int in_index =
-              (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
-          for (int c = 0; c < channels; ++c) {
-            out_data[out_index + c] = in_data[in_index + c];
-          }
-        }
-      }
-    }
-    in_data += in_depth * in_height * in_width * channels;
-    out_data += out_depth * out_height * out_width * channels;
+void ReflectPad3DFuncNDHWC(const T* in_data, T* out_data, const int channels,
+                           const int in_depth, const int in_height,
+                           const int in_width, const int out_depth,
+                           const int out_height, const int out_width,
+                           const int pad_front, const int pad_top,
+                           const int pad_left, const int out_d, const int out_h,
+                           const int out_w, const T value) {
+  int in_d = out_d - pad_front;
+  int in_h = out_h - pad_top;
+  int in_w = out_w - pad_left;
+
+  in_d = std::max(in_d, -in_d);
+  in_d = std::min(in_d, 2 * in_depth - in_d - 2);
+  in_h = std::max(in_h, -in_h);
+  in_h = std::min(in_h, 2 * in_height - in_h - 2);
+  in_w = std::max(in_w, -in_w);
+  in_w = std::min(in_w, 2 * in_width - in_w - 2);
+
+  const int out_index =
+      (out_d * out_height * out_width + out_h * out_width + out_w) * channels;
+  const int in_index =
+      (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
+  for (int c = 0; c < channels; ++c) {
+    out_data[out_index + c] = in_data[in_index + c];
   }
 }
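The two-step max/min clamp implements a single reflection about index 0 and about index n - 1, which is why the front-end docstring requires each pad to be no greater than the dimension size minus one. A quick NumPy cross-check of the mapping (illustrative, not Paddle API):

import numpy as np

def reflect_index(i, n):
    i = max(i, -i)                # reflect by 0
    return min(i, 2 * n - i - 2)  # reflect by n - 1

x = np.arange(5, dtype=np.float32)
pad = 3  # must stay <= len(x) - 1 for a single reflection
out = np.array([x[reflect_index(j - pad, len(x))]
                for j in range(len(x) + 2 * pad)])
assert np.array_equal(out, np.pad(x, pad, mode="reflect"))
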
 template <typename T>
-void Pad3DReplicateNCDHW(const T* in_data, const int num, const int channels,
-                         const int in_depth, const int in_height,
-                         const int in_width, const int out_depth,
-                         const int out_height, const int out_width,
-                         const int pad_front, const int pad_top,
-                         const int pad_left, T* out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int c = 0; c < channels; ++c) {
-      for (int out_d = 0; out_d < out_depth; ++out_d) {
-        for (int out_h = 0; out_h < out_height; ++out_h) {
-          for (int out_w = 0; out_w < out_width; ++out_w) {
-            int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
-            int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
-            int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
-
-            out_data[out_d * out_height * out_width + out_h * out_width +
-                     out_w] =
-                in_data[in_d * in_height * in_width + in_h * in_width + in_w];
-          }
-        }
-      }
-      in_data += in_depth * in_height * in_width;
-      out_data += out_depth * out_height * out_width;
-    }
-  }
+void ReplicatePad3DFuncNCDHW(const T* in_data, T* out_data, const int in_depth,
+                             const int in_height, const int in_width,
+                             const int out_depth, const int out_height,
+                             const int out_width, const int pad_front,
+                             const int pad_top, const int pad_left,
+                             const int out_d, const int out_h, const int out_w,
+                             const T value) {
+  int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
+  int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
+  int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
+
+  out_data[out_d * out_height * out_width + out_h * out_width + out_w] =
+      in_data[in_d * in_height * in_width + in_h * in_width + in_w];
+}
 
-template <typename T>
-void Pad3DReplicateNDHWC(const T* in_data, const int num, const int channels,
-                         const int in_depth, const int in_height,
-                         const int in_width, const int out_depth,
-                         const int out_height, const int out_width,
-                         const int pad_front, const int pad_top,
-                         const int pad_left, T* out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int out_d = 0; out_d < out_depth; ++out_d) {
-      for (int out_h = 0; out_h < out_height; ++out_h) {
-        for (int out_w = 0; out_w < out_width; ++out_w) {
-          int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
-          int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
-          int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
-
-          const int out_index =
-              (out_d * out_height * out_width + out_h * out_width + out_w) *
-              channels;
-          const int in_index =
-              (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
-          for (int c = 0; c < channels; ++c) {
-            out_data[out_index + c] = in_data[in_index + c];
-          }
-        }
-      }
-    }
-    in_data += in_depth * in_height * in_width * channels;
-    out_data += out_depth * out_height * out_width * channels;
+template <typename T>
+void ReplicatePad3DFuncNDHWC(const T* in_data, T* out_data, const int channels,
+                             const int in_depth, const int in_height,
+                             const int in_width, const int out_depth,
+                             const int out_height, const int out_width,
+                             const int pad_front, const int pad_top,
+                             const int pad_left, const int out_d,
+                             const int out_h, const int out_w, const T value) {
+  int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
+  int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
+  int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
+
+  const int out_index =
+      (out_d * out_height * out_width + out_h * out_width + out_w) * channels;
+  const int in_index =
+      (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
+  for (int c = 0; c < channels; ++c) {
+    out_data[out_index + c] = in_data[in_index + c];
   }
 }
 
+template <typename T>
+void CircularPad3DFuncNCDHW(const T* in_data, T* out_data, const int in_depth,
+                            const int in_height, const int in_width,
+                            const int out_depth, const int out_height,
+                            const int out_width, const int pad_front,
+                            const int pad_top, const int pad_left,
+                            const int out_d, const int out_h, const int out_w,
+                            const T value) {
+  int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
+  int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
+  int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
+
+  out_data[out_d * out_height * out_width + out_h * out_width + out_w] =
+      in_data[in_d * in_height * in_width + in_h * in_width + in_w];
+}
+
+template <typename T>
+void CircularPad3DFuncNDHWC(const T* in_data, T* out_data, const int channels,
+                            const int in_depth, const int in_height,
+                            const int in_width, const int out_depth,
+                            const int out_height, const int out_width,
+                            const int pad_front, const int pad_top,
+                            const int pad_left, const int out_d,
+                            const int out_h, const int out_w, const T value) {
+  int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
+  int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
+  int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
+
+  const int out_index =
+      (out_d * out_height * out_width + out_h * out_width + out_w) * channels;
+  const int in_index =
+      (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
+  for (int c = 0; c < channels; ++c) {
+    out_data[out_index + c] = in_data[in_index + c];
+  }
+}
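Replicate clamps to the nearest edge and circular wraps around; the double modulo is needed because C++ '%' can return a negative value for a negative operand. Both maps agree with NumPy's 'edge' and 'wrap' pad modes (illustrative sketch, not Paddle API):

import numpy as np

def replicate_index(i, n):
    return min(n - 1, max(i, 0))   # clamp to the nearest edge

def circular_index(i, n):
    return (i % n + n) % n         # double mod keeps the result non-negative

x = np.arange(4, dtype=np.float32)
pad = 2
for index_fn, np_mode in [(replicate_index, "edge"), (circular_index, "wrap")]:
    out = np.array([x[index_fn(j - pad, len(x))]
                    for j in range(len(x) + 2 * pad)])
    assert np.array_equal(out, np.pad(x, pad, mode=np_mode))
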
 template <typename T>
-void Pad3DCircularNCDHW(const T* in_data, const int num, const int channels,
-                        const int in_depth, const int in_height,
-                        const int in_width, const int out_depth,
-                        const int out_height, const int out_width,
-                        const int pad_front, const int pad_top,
-                        const int pad_left, T* out_data) {
+void Pad3DNCDHW(const T* in_data, const int num, const int channels,
+                const int in_depth, const int in_height, const int in_width,
+                const int out_depth, const int out_height, const int out_width,
+                const int pad_front, const int pad_top, const int pad_left,
+                T value, T* out_data,
+                void (*pad_func)(const T*, T*, const int, const int, const int,
+                                 const int, const int, const int, const int,
+                                 const int, const int, const int, const int,
+                                 const int, const T)) {
   for (int n = 0; n < num; ++n) {
     for (int c = 0; c < channels; ++c) {
       for (int out_d = 0; out_d < out_depth; ++out_d) {
         for (int out_h = 0; out_h < out_height; ++out_h) {
           for (int out_w = 0; out_w < out_width; ++out_w) {
-            int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
-            int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
-            int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
-
-            out_data[out_d * out_height * out_width + out_h * out_width +
-                     out_w] =
-                in_data[in_d * in_height * in_width + in_h * in_width + in_w];
+            pad_func(in_data, out_data, in_depth, in_height, in_width,
+                     out_depth, out_height, out_width, pad_front, pad_top,
+                     pad_left, out_d, out_h, out_w, value);
           }
         }
       }
@@ -257,28 +221,22 @@ void Pad3DCircularNCDHW(const T* in_data, const int num, const int channels,
 }
 
 template <typename T>
-void Pad3DCircularNDHWC(const T* in_data, const int num, const int channels,
-                        const int in_depth, const int in_height,
-                        const int in_width, const int out_depth,
-                        const int out_height, const int out_width,
-                        const int pad_front, const int pad_top,
-                        const int pad_left, T* out_data) {
+void Pad3DNDHWC(const T* in_data, const int num, const int channels,
+                const int in_depth, const int in_height, const int in_width,
+                const int out_depth, const int out_height, const int out_width,
+                const int pad_front, const int pad_top, const int pad_left,
+                T value, T* out_data,
+                void (*pad_func)(const T*, T*, const int, const int, const int,
+                                 const int, const int, const int, const int,
+                                 const int, const int, const int, const int,
+                                 const int, const int, const T)) {
   for (int n = 0; n < num; ++n) {
     for (int out_d = 0; out_d < out_depth; ++out_d) {
       for (int out_h = 0; out_h < out_height; ++out_h) {
         for (int out_w = 0; out_w < out_width; ++out_w) {
-          int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
-          int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
-          int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
-
-          const int out_index =
-              (out_d * out_height * out_width + out_h * out_width + out_w) *
-              channels;
-          const int in_index =
-              (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
-          for (int c = 0; c < channels; ++c) {
-            out_data[out_index + c] = in_data[in_index + c];
-          }
+          pad_func(in_data, out_data, channels, in_depth, in_height, in_width,
+                   out_depth, out_height, out_width, pad_front, pad_top,
+                   pad_left, out_d, out_h, out_w, value);
         }
       }
     }
@@ -288,223 +246,189 @@ void Pad3DCircularNDHWC(const T* in_data, const int num, const int channels,
   }
 }
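The two drivers now own the batch/spatial loops and take a function pointer, so the per-mode kernels shrink to the index arithmetic alone. The same idea in Python, as a hedged sketch (dict-of-callables, reusing the index maps sketched above; names are illustrative):

import numpy as np

INDEX_FN = {
    "replicate": lambda i, n: min(n - 1, max(i, 0)),
    "circular": lambda i, n: (i % n + n) % n,
}

def pad1d(x, pad, mode):
    fn = INDEX_FN[mode]  # resolved once, outside the hot loop
    n = len(x)
    return np.array([x[fn(j - pad, n)] for j in range(n + 2 * pad)])

print(pad1d(np.arange(3.0), 1, "circular"))  # [2. 0. 1. 2. 0.]
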
 template <typename T>
-void Pad3DGradConstNCDHW(T* d_in_data, const int num, const int channels,
-                         const int in_depth, const int in_height,
-                         const int in_width, const int out_depth,
-                         const int out_height, const int out_width,
-                         const int pad_front, const int pad_top,
-                         const int pad_left, const T* d_out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int c = 0; c < channels; ++c) {
-      for (int out_d = 0; out_d < out_depth; ++out_d) {
-        for (int out_h = 0; out_h < out_height; ++out_h) {
-          for (int out_w = 0; out_w < out_width; ++out_w) {
-            int in_d = out_d - pad_front;
-            int in_h = out_h - pad_top;
-            int in_w = out_w - pad_left;
-            if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth ||
-                  in_h >= in_height || in_w >= in_width)) {
-              d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] =
-                  d_out_data[out_d * out_height * out_width +
-                             out_h * out_width + out_w];
-            }
-          }
-        }
-      }
-      d_in_data += in_depth * in_height * in_width;
-      d_out_data += out_depth * out_height * out_width;
-    }
-  }
+void ConstPad3DGradNCDHW(T* d_in_data, const T* d_out_data, const int in_depth,
+                         const int in_height, const int in_width,
+                         const int out_depth, const int out_height,
+                         const int out_width, const int pad_front,
+                         const int pad_top, const int pad_left, const int out_d,
+                         const int out_h, const int out_w) {
+  int in_d = out_d - pad_front;
+  int in_h = out_h - pad_top;
+  int in_w = out_w - pad_left;
+  if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth ||
+        in_h >= in_height || in_w >= in_width)) {
+    d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] =
+        d_out_data[out_d * out_height * out_width + out_h * out_width + out_w];
+  }
 }
 
 template <typename T>
-void Pad3DGradConstNDHWC(T* d_in_data, const int num, const int channels,
+void ConstPad3DGradNDHWC(T* d_in_data, const T* d_out_data, const int channels,
                          const int in_depth, const int in_height,
                          const int in_width, const int out_depth,
                          const int out_height, const int out_width,
                          const int pad_front, const int pad_top,
-                         const int pad_left, const T* d_out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int out_d = 0; out_d < out_depth; ++out_d) {
-      for (int out_h = 0; out_h < out_height; ++out_h) {
-        for (int out_w = 0; out_w < out_width; ++out_w) {
-          int in_d = out_d - pad_front;
-          int in_h = out_h - pad_top;
-          int in_w = out_w - pad_left;
-
-          const int out_index =
-              (out_d * out_height * out_width + out_h * out_width + out_w) *
-              channels;
-          if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth ||
-                in_h >= in_height || in_w >= in_width)) {
-            const int in_index =
-                (in_d * in_height * in_width + in_h * in_width + in_w) *
-                channels;
-            for (int c = 0; c < channels; ++c) {
-              d_in_data[in_index + c] = d_out_data[out_index + c];
-            }
-          }
-        }
-      }
+                         const int pad_left, const int out_d, const int out_h,
+                         const int out_w) {
+  int in_d = out_d - pad_front;
+  int in_h = out_h - pad_top;
+  int in_w = out_w - pad_left;
+
+  const int out_index =
+      (out_d * out_height * out_width + out_h * out_width + out_w) * channels;
+  if (!(in_d < 0 || in_h < 0 || in_w < 0 || in_d >= in_depth ||
+        in_h >= in_height || in_w >= in_width)) {
+    const int in_index =
+        (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
+    for (int c = 0; c < channels; ++c) {
+      d_in_data[in_index + c] = d_out_data[out_index + c];
     }
-    d_in_data += in_depth * in_height * in_width * channels;
-    d_out_data += out_depth * out_height * out_width * channels;
   }
 }
 
 template <typename T>
-void Pad3DGradReflectNCDHW(T* d_in_data, const int num, const int channels,
+void ReflectPad3DGradNCDHW(T* d_in_data, const T* d_out_data,
                            const int in_depth, const int in_height,
                            const int in_width, const int out_depth,
                            const int out_height, const int out_width,
                            const int pad_front, const int pad_top,
-                           const int pad_left, const T* d_out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int c = 0; c < channels; ++c) {
-      for (int out_d = 0; out_d < out_depth; ++out_d) {
-        for (int out_h = 0; out_h < out_height; ++out_h) {
-          for (int out_w = 0; out_w < out_width; ++out_w) {
-            int in_d = out_d - pad_front;
-            int in_h = out_h - pad_top;
-            int in_w = out_w - pad_left;
-
-            in_d = std::max(in_d, -in_d);  // reflect by 0
-            in_d =
-                std::min(in_d, 2 * in_depth - in_d - 2);  // reflect by in_depth
-            in_h = std::max(in_h, -in_h);  // reflect by 0
-            in_h = std::min(in_h,
-                            2 * in_height - in_h - 2);  // reflect by in_height
-            in_w = std::max(in_w, -in_w);  // reflect by 0
-            in_w =
-                std::min(in_w, 2 * in_width - in_w - 2);  // reflect by in_width
-
-            d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
-                d_out_data[out_d * out_height * out_width + out_h * out_width +
-                           out_w];
-          }
-        }
-      }
-      d_in_data += in_depth * in_height * in_width;
-      d_out_data += out_depth * out_height * out_width;
-    }
-  }
+                           const int pad_left, const int out_d, const int out_h,
+                           const int out_w) {
+  int in_d = out_d - pad_front;
+  int in_h = out_h - pad_top;
+  int in_w = out_w - pad_left;
+
+  in_d = std::max(in_d, -in_d);                     // reflect by 0
+  in_d = std::min(in_d, 2 * in_depth - in_d - 2);   // reflect by in_depth
+  in_h = std::max(in_h, -in_h);                     // reflect by 0
+  in_h = std::min(in_h, 2 * in_height - in_h - 2);  // reflect by in_height
+  in_w = std::max(in_w, -in_w);                     // reflect by 0
+  in_w = std::min(in_w, 2 * in_width - in_w - 2);   // reflect by in_width
+
+  d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
+      d_out_data[out_d * out_height * out_width + out_h * out_width + out_w];
 }
 
 template <typename T>
-void Pad3DGradReflectNDHWC(T* d_in_data, const int num, const int channels,
-                           const int in_depth, const int in_height,
-                           const int in_width, const int out_depth,
-                           const int out_height, const int out_width,
-                           const int pad_front, const int pad_top,
-                           const int pad_left, const T* d_out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int out_d = 0; out_d < out_depth; ++out_d) {
-      for (int out_h = 0; out_h < out_height; ++out_h) {
-        for (int out_w = 0; out_w < out_width; ++out_w) {
-          int in_d = out_d - pad_front;
-          int in_h = out_h - pad_top;
-          int in_w = out_w - pad_left;
-
-          in_d = std::max(in_d, -in_d);
-          in_d = std::min(in_d, 2 * in_depth - in_d - 2);
-          in_h = std::max(in_h, -in_h);
-          in_h = std::min(in_h, 2 * in_height - in_h - 2);
-          in_w = std::max(in_w, -in_w);
-          in_w = std::min(in_w, 2 * in_width - in_w - 2);
-
-          const int out_index =
-              (out_d * out_height * out_width + out_h * out_width + out_w) *
-              channels;
-          const int in_index =
-              (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
-          for (int c = 0; c < channels; ++c) {
-            d_in_data[in_index + c] += d_out_data[out_index + c];
-          }
-        }
-      }
-    }
-    d_in_data += in_depth * in_height * in_width * channels;
-    d_out_data += out_depth * out_height * out_width * channels;
+void ReflectPad3DGradNDHWC(T* d_in_data, const T* d_out_data,
+                           const int channels, const int in_depth,
+                           const int in_height, const int in_width,
+                           const int out_depth, const int out_height,
+                           const int out_width, const int pad_front,
+                           const int pad_top, const int pad_left,
+                           const int out_d, const int out_h, const int out_w) {
+  int in_d = out_d - pad_front;
+  int in_h = out_h - pad_top;
+  int in_w = out_w - pad_left;
+
+  in_d = std::max(in_d, -in_d);
+  in_d = std::min(in_d, 2 * in_depth - in_d - 2);
+  in_h = std::max(in_h, -in_h);
+  in_h = std::min(in_h, 2 * in_height - in_h - 2);
+  in_w = std::max(in_w, -in_w);
+  in_w = std::min(in_w, 2 * in_width - in_w - 2);
+
+  const int out_index =
+      (out_d * out_height * out_width + out_h * out_width + out_w) * channels;
+  const int in_index =
+      (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
+  for (int c = 0; c < channels; ++c) {
+    d_in_data[in_index + c] += d_out_data[out_index + c];
   }
 }
 
 template <typename T>
-void Pad3DGradReplicateNCDHW(T* d_in_data, const int num, const int channels,
+void ReplicatePad3DGradNCDHW(T* d_in_data, const T* d_out_data,
                              const int in_depth, const int in_height,
                              const int in_width, const int out_depth,
                              const int out_height, const int out_width,
                              const int pad_front, const int pad_top,
-                             const int pad_left, const T* d_out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int c = 0; c < channels; ++c) {
-      for (int out_d = 0; out_d < out_depth; ++out_d) {
-        for (int out_h = 0; out_h < out_height; ++out_h) {
-          for (int out_w = 0; out_w < out_width; ++out_w) {
-            int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
-            int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
-            int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
-
-            d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
-                d_out_data[out_d * out_height * out_width + out_h * out_width +
-                           out_w];
-          }
-        }
-      }
-      d_in_data += in_depth * in_height * in_width;
-      d_out_data += out_depth * out_height * out_width;
-    }
-  }
+                             const int pad_left, const int out_d,
+                             const int out_h, const int out_w) {
+  int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
+  int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
+  int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
+
+  d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
+      d_out_data[out_d * out_height * out_width + out_h * out_width + out_w];
 }
 
 template <typename T>
-void Pad3DGradReplicateNDHWC(T* d_in_data, const int num, const int channels,
-                             const int in_depth, const int in_height,
-                             const int in_width, const int out_depth,
-                             const int out_height, const int out_width,
-                             const int pad_front, const int pad_top,
-                             const int pad_left, const T* d_out_data) {
-  for (int n = 0; n < num; ++n) {
-    for (int out_d = 0; out_d < out_depth; ++out_d) {
-      for (int out_h = 0; out_h < out_height; ++out_h) {
-        for (int out_w = 0; out_w < out_width; ++out_w) {
-          int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
-          int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
-          int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
-
-          const int out_index =
-              (out_d * out_height * out_width + out_h * out_width + out_w) *
-              channels;
-          const int in_index =
-              (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
-          for (int c = 0; c < channels; ++c) {
-            d_in_data[in_index + c] += d_out_data[out_index + c];
-          }
-        }
-      }
-    }
-    d_in_data += in_depth * in_height * in_width * channels;
-    d_out_data += out_depth * out_height * out_width * channels;
+void ReplicatePad3DGradNDHWC(T* d_in_data, const T* d_out_data,
+                             const int channels, const int in_depth,
+                             const int in_height, const int in_width,
+                             const int out_depth, const int out_height,
+                             const int out_width, const int pad_front,
+                             const int pad_top, const int pad_left,
+                             const int out_d, const int out_h,
+                             const int out_w) {
+  int in_d = std::min(in_depth - 1, std::max(out_d - pad_front, 0));
+  int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0));
+  int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0));
+
+  const int out_index =
+      (out_d * out_height * out_width + out_h * out_width + out_w) * channels;
+  const int in_index =
+      (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
+  for (int c = 0; c < channels; ++c) {
+    d_in_data[in_index + c] += d_out_data[out_index + c];
   }
 }
 
 template <typename T>
-void Pad3DGradCircularNCDHW(T* d_in_data, const int num, const int channels,
+void CircularPad3DGradNCDHW(T* d_in_data, const T* d_out_data,
                             const int in_depth, const int in_height,
                             const int in_width, const int out_depth,
                             const int out_height, const int out_width,
                             const int pad_front, const int pad_top,
-                            const int pad_left, const T* d_out_data) {
+                            const int pad_left, const int out_d,
+                            const int out_h, const int out_w) {
+  int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
+  int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
+  int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
+  d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
+      d_out_data[out_d * out_height * out_width + out_h * out_width + out_w];
+}
+
+template <typename T>
+void CircularPad3DGradNDHWC(T* d_in_data, const T* d_out_data,
+                            const int channels, const int in_depth,
+                            const int in_height, const int in_width,
+                            const int out_depth, const int out_height,
+                            const int out_width, const int pad_front,
+                            const int pad_top, const int pad_left,
+                            const int out_d, const int out_h, const int out_w) {
+  int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
+  int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
+  int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
+
+  const int out_index =
+      (out_d * out_height * out_width + out_h * out_width + out_w) * channels;
+  const int in_index =
+      (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
+  for (int c = 0; c < channels; ++c) {
+    d_in_data[in_index + c] += d_out_data[out_index + c];
+  }
+}
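The backward functors invert the same index maps: each output-gradient element is accumulated into the input position it was read from. The '+=' matters for reflect/replicate/circular, where several output positions can map to one input element; constant mode can assign directly because its in-bounds mapping is one-to-one. A NumPy sketch with np.add.at (reflect, 1-D; illustrative names):

import numpy as np

def reflect_index(i, n):
    i = max(i, -i)
    return min(i, 2 * n - i - 2)

n, pad = 4, 2
d_out = np.ones(n + 2 * pad, dtype=np.float32)  # upstream gradient
idx = [reflect_index(j - pad, n) for j in range(len(d_out))]
d_in = np.zeros(n, dtype=np.float32)
np.add.at(d_in, idx, d_out)                     # scatter-add, the '+=' above
print(d_in)  # [1. 3. 3. 1.]
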
 
 template <typename T>
+void Pad3DGradNCDHW(T* d_in_data, const int num, const int channels,
+                    const int in_depth, const int in_height, const int in_width,
+                    const int out_depth, const int out_height,
+                    const int out_width, const int pad_front, const int pad_top,
+                    const int pad_left, const T* d_out_data,
+                    void (*pad_func)(T*, const T*, const int, const int,
+                                     const int, const int, const int, const int,
+                                     const int, const int, const int, const int,
+                                     const int, const int)) {
   for (int n = 0; n < num; ++n) {
     for (int c = 0; c < channels; ++c) {
       for (int out_d = 0; out_d < out_depth; ++out_d) {
         for (int out_h = 0; out_h < out_height; ++out_h) {
           for (int out_w = 0; out_w < out_width; ++out_w) {
-            int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
-            int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
-            int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
-            d_in_data[in_d * in_height * in_width + in_h * in_width + in_w] +=
-                d_out_data[out_d * out_height * out_width + out_h * out_width +
-                           out_w];
+            pad_func(d_in_data, d_out_data, in_depth, in_height, in_width,
+                     out_depth, out_height, out_width, pad_front, pad_top,
+                     pad_left, out_d, out_h, out_w);
           }
         }
       }
@@ -515,28 +439,22 @@ void Pad3DGradCircularNCDHW(T* d_in_data, const int num, const int channels,
 }
 
 template <typename T>
-void Pad3DGradCircularNDHWC(T* d_in_data, const int num, const int channels,
-                            const int in_depth, const int in_height,
-                            const int in_width, const int out_depth,
-                            const int out_height, const int out_width,
-                            const int pad_front, const int pad_top,
-                            const int pad_left, const T* d_out_data) {
+void Pad3DGradNDHWC(T* d_in_data, const int num, const int channels,
+                    const int in_depth, const int in_height, const int in_width,
+                    const int out_depth, const int out_height,
+                    const int out_width, const int pad_front, const int pad_top,
+                    const int pad_left, const T* d_out_data,
+                    void (*pad_func)(T*, const T*, const int, const int,
+                                     const int, const int, const int, const int,
+                                     const int, const int, const int, const int,
+                                     const int, const int, const int)) {
   for (int n = 0; n < num; ++n) {
     for (int out_d = 0; out_d < out_depth; ++out_d) {
       for (int out_h = 0; out_h < out_height; ++out_h) {
         for (int out_w = 0; out_w < out_width; ++out_w) {
-          int in_d = ((out_d - pad_front) % in_depth + in_depth) % in_depth;
-          int in_h = ((out_h - pad_top) % in_height + in_height) % in_height;
-          int in_w = ((out_w - pad_left) % in_width + in_width) % in_width;
-
-          const int out_index =
-              (out_d * out_height * out_width + out_h * out_width + out_w) *
-              channels;
-          const int in_index =
-              (in_d * in_height * in_width + in_h * in_width + in_w) * channels;
-          for (int c = 0; c < channels; ++c) {
-            d_in_data[in_index + c] += d_out_data[out_index + c];
-          }
+          pad_func(d_in_data, d_out_data, channels, in_depth, in_height,
+                   in_width, out_depth, out_height, out_width, pad_front,
+                   pad_top, pad_left, out_d, out_h, out_w);
         }
       }
     }
@@ -545,29 +463,25 @@ void Pad3DGradCircularNDHWC(T* d_in_data, const int num, const int channels,
   }
 }
 
-static inline void GetPaddings(int* paddings,
-                               const framework::ExecutionContext& context) {
+static inline std::vector<int> GetPaddings(
+    const framework::ExecutionContext& context) {
+  std::vector<int> paddings(6);
   auto* paddings_t = context.Input<Tensor>("Paddings");
   if (paddings_t) {
    auto paddings_data = paddings_t->data<int>();
-    paddings[0] = paddings_data[0];
-    paddings[1] = paddings_data[1];
-    paddings[2] = paddings_data[2];
-    paddings[3] = paddings_data[3];
-    paddings[4] = paddings_data[4];
-    paddings[5] = paddings_data[5];
+    std::memcpy(paddings.data(), paddings_data, paddings.size() * sizeof(int));
  } else {
    auto pads = context.Attr<std::vector<int>>("paddings");
-    std::copy(pads.begin(), pads.end(), paddings);
+    std::copy(pads.begin(), pads.end(), paddings.data());
  }
+  return paddings;
 }
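GetPaddings reads the six pad values either from the optional 'Paddings' input tensor or from the 'paddings' attribute. At the Python level this corresponds to passing `pad` as a Variable or as a plain list; a hedged usage sketch in the 2.0-alpha dygraph style used elsewhere in this patch:

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.to_variable(np.ones((1, 1, 2, 3), dtype=np.float32))

# attribute path: a plain Python list ends up in attrs['paddings']
y = F.pad(x, pad=[1, 1, 0, 0], mode="constant", value=0)
print(y.numpy().shape)  # (1, 1, 2, 5)

# tensor path: an int32 Variable is wired to the 'Paddings' input instead
pad_t = paddle.to_variable(np.array([1, 1, 0, 0], dtype="int32"))
y = F.pad(x, pad=pad_t, mode="constant", value=0)
print(y.numpy().shape)  # (1, 1, 2, 5)
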
 template <typename T>
 class Pad3dCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int pads[6];
-    GetPaddings(pads, context);
+    std::vector<int> pads = GetPaddings(context);
     auto mode = context.Attr<std::string>("mode");
     auto data_format = context.Attr<std::string>("data_format");
     T value = static_cast<T>(context.Attr<float>("value"));
@@ -658,41 +572,33 @@ class Pad3dCPUKernel : public framework::OpKernel<T> {
     const int pad_front = pads[4];
     const int num = in_dims[0];
     if (data_format == "NCDHW") {
-      if (mode == "reflect") {
-        Pad3DReflectNCDHW(in_data, num, channels, in_depth, in_height, in_width,
-                          out_depth, out_height, out_width, pad_front, pad_top,
-                          pad_left, out_data);
-      } else if (mode == "replicate") {
-        Pad3DReplicateNCDHW(in_data, num, channels, in_depth, in_height,
-                            in_width, out_depth, out_height, out_width,
-                            pad_front, pad_top, pad_left, out_data);
-      } else if (mode == "circular") {
-        Pad3DCircularNCDHW(in_data, num, channels, in_depth, in_height,
-                           in_width, out_depth, out_height, out_width,
-                           pad_front, pad_top, pad_left, out_data);
-      } else if (mode == "constant") {
-        Pad3DConstNCDHW(in_data, num, channels, in_depth, in_height, in_width,
-                        out_depth, out_height, out_width, pad_front, pad_top,
-                        pad_left, value, out_data);
-      }
+      std::map<std::string,
+               void (*)(const T*, T*, const int, const int, const int,
+                        const int, const int, const int, const int, const int,
+                        const int, const int, const int, const int, const T)>
+          func_map;
+
+      func_map["reflect"] = ReflectPad3DFuncNCDHW;
+      func_map["replicate"] = ReplicatePad3DFuncNCDHW;
+      func_map["circular"] = CircularPad3DFuncNCDHW;
+      func_map["constant"] = ConstPad3DFuncNCDHW;
+      Pad3DNCDHW(in_data, num, channels, in_depth, in_height, in_width,
+                 out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+                 value, out_data, func_map[mode]);
     } else {
-      if (mode == "reflect") {
-        Pad3DReflectNDHWC(in_data, num, channels, in_depth, in_height, in_width,
-                          out_depth, out_height, out_width, pad_front, pad_top,
-                          pad_left, out_data);
-      } else if (mode == "replicate") {
-        Pad3DReplicateNDHWC(in_data, num, channels, in_depth, in_height,
-                            in_width, out_depth, out_height, out_width,
-                            pad_front, pad_top, pad_left, out_data);
-      } else if (mode == "circular") {
-        Pad3DCircularNDHWC(in_data, num, channels, in_depth, in_height,
-                           in_width, out_depth, out_height, out_width,
-                           pad_front, pad_top, pad_left, out_data);
-      } else {
-        Pad3DConstNDHWC(in_data, num, channels, in_depth, in_height, in_width,
-                        out_depth, out_height, out_width, pad_front, pad_top,
-                        pad_left, value, out_data);
-      }
+      std::map<std::string,
               void (*)(const T*, T*, const int, const int, const int,
+                        const int, const int, const int, const int, const int,
+                        const int, const int, const int, const int, const int,
+                        const T)>
+          func_map;
+
+      func_map["reflect"] = ReflectPad3DFuncNDHWC;
+      func_map["replicate"] = ReplicatePad3DFuncNDHWC;
+      func_map["circular"] = CircularPad3DFuncNDHWC;
+      func_map["constant"] = ConstPad3DFuncNDHWC;
+      Pad3DNDHWC(in_data, num, channels, in_depth, in_height, in_width,
+                 out_depth, out_height, out_width, pad_front, pad_top, pad_left,
+                 value, out_data, func_map[mode]);
     }
   }
 };
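One caveat with this dispatch style: std::map::operator[] default-inserts a null function pointer for an unknown key, so func_map[mode] relies on `mode` having been validated before the kernel runs (the Python wrapper's assert does this). In Python the fail-fast variant is explicit; a sketch with illustrative placeholder values:

PAD_FUNCS = {
    "constant": "ConstPad3DFuncNCDHW",
    "reflect": "ReflectPad3DFuncNCDHW",
    "replicate": "ReplicatePad3DFuncNCDHW",
    "circular": "CircularPad3DFuncNCDHW",
}

def lookup(mode):
    try:
        return PAD_FUNCS[mode]
    except KeyError:
        raise ValueError("mode should be one of {}, but got {}."
                         .format(sorted(PAD_FUNCS), mode))
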
@@ -701,8 +607,7 @@ template <typename T>
 class Pad3dGradCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int pads[6];
-    GetPaddings(pads, context);
+    std::vector<int> pads = GetPaddings(context);
     auto mode = context.Attr<std::string>("mode");
     auto data_format = context.Attr<std::string>("data_format");
     auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
@@ -726,23 +631,21 @@ class Pad3dGradCPUKernel : public framework::OpKernel<T> {
       const int out_depth = d_out_dims[2];
       const int out_height = d_out_dims[3];
       const int out_width = d_out_dims[4];
-      if (mode == "reflect") {
-        Pad3DGradReflectNCDHW(d_in_data, num, channels, in_depth, in_height,
-                              in_width, out_depth, out_height, out_width,
-                              pad_front, pad_top, pad_left, d_out_data);
-      } else if (mode == "replicate") {
-        Pad3DGradReplicateNCDHW(d_in_data, num, channels, in_depth, in_height,
-                                in_width, out_depth, out_height, out_width,
-                                pad_front, pad_top, pad_left, d_out_data);
-      } else if (mode == "circular") {
-        Pad3DGradCircularNCDHW(d_in_data, num, channels, in_depth, in_height,
-                               in_width, out_depth, out_height, out_width,
-                               pad_front, pad_top, pad_left, d_out_data);
-      } else {
-        Pad3DGradConstNCDHW(d_in_data, num, channels, in_depth, in_height,
-                            in_width, out_depth, out_height, out_width,
-                            pad_front, pad_top, pad_left, d_out_data);
-      }
+
+      std::map<std::string,
+               void (*)(T*, const T*, const int, const int, const int,
+                        const int, const int, const int, const int, const int,
+                        const int, const int, const int, const int)>
+          func_map;
+
+      func_map["reflect"] = ReflectPad3DGradNCDHW;
+      func_map["replicate"] = ReplicatePad3DGradNCDHW;
+      func_map["circular"] = CircularPad3DGradNCDHW;
+      func_map["constant"] = ConstPad3DGradNCDHW;
+
+      Pad3DGradNCDHW(d_in_data, num, channels, in_depth, in_height, in_width,
+                     out_depth, out_height, out_width, pad_front, pad_top,
+                     pad_left, d_out_data, func_map[mode]);
     } else {
       const int channels = d_in_dims[4];
       const int in_depth = d_in_dims[1];
@@ -751,23 +654,21 @@ class Pad3dGradCPUKernel : public framework::OpKernel<T> {
       const int out_depth = d_out_dims[1];
       const int out_height = d_out_dims[2];
       const int out_width = d_out_dims[3];
-      if (mode == "reflect") {
-        Pad3DGradReflectNDHWC(d_in_data, num, channels, in_depth, in_height,
-                              in_width, out_depth, out_height, out_width,
-                              pad_front, pad_top, pad_left, d_out_data);
-      } else if (mode == "replicate") {
-        Pad3DGradReplicateNDHWC(d_in_data, num, channels, in_depth, in_height,
-                                in_width, out_depth, out_height, out_width,
-                                pad_front, pad_top, pad_left, d_out_data);
-      } else if (mode == "circular") {
-        Pad3DGradCircularNDHWC(d_in_data, num, channels, in_depth, in_height,
-                               in_width, out_depth, out_height, out_width,
-                               pad_front, pad_top, pad_left, d_out_data);
-      } else {
-        Pad3DGradConstNDHWC(d_in_data, num, channels, in_depth, in_height,
-                            in_width, out_depth, out_height, out_width,
-                            pad_front, pad_top, pad_left, d_out_data);
-      }
+
+      std::map<std::string,
+               void (*)(T*, const T*, const int, const int, const int,
+                        const int, const int, const int, const int, const int,
+                        const int, const int, const int, const int, const int)>
+          func_map;
+
+      func_map["reflect"] = ReflectPad3DGradNDHWC;
+      func_map["replicate"] = ReplicatePad3DGradNDHWC;
+      func_map["circular"] = CircularPad3DGradNDHWC;
+      func_map["constant"] = ConstPad3DGradNDHWC;
+
+      Pad3DGradNDHWC(d_in_data, num, channels, in_depth, in_height, in_width,
                     out_depth, out_height, out_width, pad_front, pad_top,
+                     pad_left, d_out_data, func_map[mode]);
     }
   }
 };
diff --git a/paddle/fluid/operators/pad3d_op.cu b/paddle/fluid/operators/pad3d_op.cu
index 3ab811f4905a0..672a75389ccf1 100644
--- a/paddle/fluid/operators/pad3d_op.cu
+++ b/paddle/fluid/operators/pad3d_op.cu
@@ -511,31 +511,27 @@ __global__ void Pad3DGradCircularNDHWC(const int out_size, T* d_in_data,
   }
 }
 
-static inline void GetPaddings(int* paddings,
-                               const framework::ExecutionContext& context) {
+static inline std::vector<int> GetPaddings(
+    const framework::ExecutionContext& context) {
+  std::vector<int> paddings(6);
   auto* paddings_data = context.Input<Tensor>("Paddings");
   if (paddings_data) {
     Tensor pads;
     framework::TensorCopySync(*paddings_data, platform::CPUPlace(), &pads);
     auto pads_data = pads.data<int>();
-    paddings[0] = pads_data[0];
-    paddings[1] = pads_data[1];
-    paddings[2] = pads_data[2];
-    paddings[3] = pads_data[3];
-    paddings[4] = pads_data[4];
-    paddings[5] = pads_data[5];
+    std::memcpy(paddings.data(), pads_data, paddings.size() * sizeof(int));
   } else {
     auto pads = context.Attr<std::vector<int>>("paddings");
-    std::copy(pads.begin(), pads.end(), paddings);
+    std::copy(pads.begin(), pads.end(), paddings.data());
   }
+  return paddings;
 }
 
 template <typename T>
 class Pad3dCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int pads[6];
-    GetPaddings(pads, context);
+    std::vector<int> pads = GetPaddings(context);
     auto mode = context.Attr<std::string>("mode");
     auto data_format = context.Attr<std::string>("data_format");
     T value = static_cast<T>(context.Attr<float>("value"));
@@ -686,8 +682,7 @@ template <typename T>
 class Pad3dGradCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    int pads[6];
-    GetPaddings(pads, context);
+    std::vector<int> pads = GetPaddings(context);
     auto mode = context.Attr<std::string>("mode");
     auto data_format = context.Attr<std::string>("data_format");
     auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
index 080b043d0070f..8b7eeaf53ffff 100644
--- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
@@ -201,7 +201,7 @@ def test_dygraph(self):
         place = paddle.CPUPlace()
         with dg.guard(place) as g:
             input = dg.to_variable(input_data)
-            output = F.pad(input=input, pad=pad, mode=mode, value=value)
+            output = F.pad(x=input, pad=pad, mode=mode, value=value)
             self.assertTrue(np.allclose(output.numpy(), np_out))
 
@@ -248,7 +248,7 @@ def check_static_result(self, place):
             value = 100
             input_data = np.random.rand(*input_shape).astype(np.float32)
             x = paddle.data(name="x", shape=input_shape)
-            result = F.pad(input=x, pad=pad, value=value, mode='constant')
+            result = F.pad(x=x, pad=pad, value=value, mode='constant')
             exe = Executor(place)
             fetches = exe.run(default_main_program(),
                               feed={"x": input_data},
@@ -272,7 +272,7 @@ def test_dygraph(self):
             np_out = self._get_numpy_out(input_data, pad, mode, value)
             with dg.guard(place) as g:
                 input = dg.to_variable(input_data)
-                output = F.pad(input=input, pad=pad, mode=mode, value=value)
+                output = F.pad(x=input, pad=pad, mode=mode, value=value)
                 self.assertTrue(np.allclose(output.numpy(), np_out))
 
@@ -345,7 +345,7 @@ def test_errors(self):
         def test_variable():
             input_shape = (1, 2, 3, 4, 5)
             data = np.random.rand(*input_shape).astype(np.float32)
-            F.pad(input=data, paddings=[1, 1, 1, 1, 1, 1])
+            F.pad(x=data, paddings=[1, 1, 1, 1, 1, 1])
 
         def test_reflect_1():
             input_shape = (1, 2, 3, 4, 5)
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 4223b0507fa76..e32c30573c0bb 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -65,6 +65,13 @@
 from .layer.common import ReflectionPad1d  #DEFINE_ALIAS
 from .layer.common import ReplicationPad1d  #DEFINE_ALIAS
 from .layer.common import ConstantPad1d  #DEFINE_ALIAS
+from .layer.common import ReflectionPad2d  #DEFINE_ALIAS
+from .layer.common import ReplicationPad2d  #DEFINE_ALIAS
+from .layer.common import ConstantPad2d  #DEFINE_ALIAS
+from .layer.common import ZeroPad2d  #DEFINE_ALIAS
+from .layer.common import ReplicationPad3d  #DEFINE_ALIAS
+from .layer.common import ConstantPad3d  #DEFINE_ALIAS
+from .layer.common import CosineSimilarity  #DEFINE_ALIAS
 from .layer.common import Embedding  #DEFINE_ALIAS
 from .layer.common import Linear  #DEFINE_ALIAS
 from .layer.common import Flatten  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 3fefb1b053ee8..7acfa454feb48 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -58,6 +58,7 @@
 from .common import pad  #DEFINE_ALIAS
 from .common import pad_constant_like  #DEFINE_ALIAS
 from .common import pad2d  #DEFINE_ALIAS
+from .common import cosine_similarity  #DEFINE_ALIAS
 from .common import unfold  #DEFINE_ALIAS
 # from .common import bilinear_tensor_product  #DEFINE_ALIAS
 from .common import assign  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 4b92a6a37c3f4..db90f81d056a5 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import warnings
+import paddle.fluid.core as core
+from ...fluid.framework import in_dygraph_mode
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.layers.tensor import Variable, fill_constant, zeros, concat
@@ -25,6 +27,10 @@
 from ...fluid.layers import assign  #DEFINE_ALIAS
 from ...fluid.layers import squeeze
 from ...fluid.layers import unsqueeze
+from ...fluid.layers import elementwise_mul
+from ...tensor import clamp
+from ...tensor import sum
+from ...tensor import sqrt
 #from ...fluid.layers import fc  #DEFINE_ALIAS
 from ...fluid.layers import pad_constant_like  #DEFINE_ALIAS
@@ -449,23 +455,19 @@ def _is_list_or_turple_(data):
     return out
 
 
-def pad(input,
+def pad(x,
         pad=[0, 0, 0, 0],
         mode='constant',
         value=0,
         data_format="NCHW",
         name=None):
     """
-    :alias_main: paddle.nn.functional.pad
-    :alias: paddle.nn.functional.pad,paddle.nn.functional.common.pad
-    :old_api: paddle.fluid.layers.pad2d
-
     Pad tensor according to 'pad' and 'mode'.
     If mode is 'reflect', pad[0] and pad[1] must be no greater
    than width-1. The height and depth dimensions have the same condition.
 
     Parameters:
-        input (Variable): The input tensor with data type float32/double/int32/int64_t.
+        x (Variable): The input tensor with data type float32/float64/int32/int64.
         pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
             of input will be padded. 1. If input dimension is 3, then the pad has the form (pad_left,
             pad_right). 2. If the input dimension is 4, then the pad has the form (pad_left, pad_right,
             pad_top, pad_bottom). 3. If the input dimension is 5, then the pad has the form
             (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). Default is [0, 0, 0, 0].
 
         mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'.
            When in 'constant' mode, this op uses a constant value to pad the input tensor.
            When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
            When in 'replicate' mode, uses input boundaries to pad the input tensor.
            When in 'circular' mode, uses circular input to pad the input tensor.
            Default is 'constant'
        value (float32): The value to fill the padded areas in 'constant' mode. Default is 0.0
        data_format (str): A string from: "NCL", "NLC", "NHWC", "NCHW", "NCDHW", "NDHWC". Specify the data format of
           the input data.
           Default is "NCHW"
        name (str, optional) : The default value is None.  Normally there is no need for
            user to set this property.  For more information, please refer to :ref:`api_guide_Name`.
 
    Returns: a Tensor padded according to pad and mode, with the same data type as x.
    Return Type: Variable
 
    Examples:
        ..
 code-block:: text
 
-            Input = [[[[[1., 2., 3.],
-                        [4., 5., 6.]]]]]
+            x = [[[[[1., 2., 3.],
+                    [4., 5., 6.]]]]]
 
             Case 0:
                 pad = [2, 2, 1, 1, 0, 0],
                 mode = 'constant'
-                pad_value = 0
+                value = 0
                 Out = [[[[[0. 0. 0. 0. 0. 0. 0.]
                           [0. 0. 1. 2. 3. 0. 0.]
                           [0. 0. 4. 5. 6. 0. 0.]
                           [0. 0. 0. 0. 0. 0. 0.]]]]]
 
@@ -529,39 +531,37 @@ def pad(x,
     Code Examples:
         .. code-block:: python
 
-            # declarative mode
             import numpy as np
             import paddle
             import paddle.nn.functional as F
-
-            input_shape = (1, 1, 3)
-            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
-            x = paddle.data(name="x", shape=input_shape)
-            y = F.pad(x, pad=[2, 3], value=1, mode='constant')
-            place = paddle.CPUPlace()
-            exe = paddle.Executor(place)
-            outputs = exe.run(feed={'x': data}, fetch_list=[y.name])
-            print(outputs[0])
+
+            paddle.disable_static()
+
+            # example 1
+            x_shape = (1, 1, 3)
+            x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + 1
+            tensor_x = paddle.to_variable(x)
+            y = F.pad(tensor_x, pad=[2, 3], value=1, mode='constant')
+            print(y.numpy())
             # [[[1. 1. 1. 2. 3. 1. 1. 1.]]]
 
-            # imperative mode
-            import paddle.fluid.dygraph as dg
-            input_shape = (1, 1, 2, 3)
-            pad = [1, 2, 1, 1]
-            input_data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
-            with dg.guard(place) as g:
-                input = dg.to_variable(input_data)
-                output = paddle.nn.functional.pad(input=input, pad=pad, mode="circular")
-                print(output.numpy())
-                # [[[[6. 4. 5. 6. 4. 5.]
-                #    [3. 1. 2. 3. 1. 2.]
-                #    [6. 4. 5. 6. 4. 5.]
-                #    [3. 1. 2. 3. 1. 2.]]]]
-
+            # example 2
+            x_shape = (1, 1, 2, 3)
+            x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + 1
+            tensor_x = paddle.to_variable(x)
+            y = F.pad(tensor_x, pad=[1, 2, 1, 1], value=1, mode='circular')
+            print(y.numpy())
+            # [[[[6. 4. 5. 6. 4. 5.]
+            #    [3. 1. 2. 3. 1. 2.]
+            #    [6. 4. 5. 6. 4. 5.]
+            #    [3. 1. 2. 3. 1. 2.]]]]
     """
+    assert mode in ['reflect', 'replicate', 'constant', 'circular'], \
+        "mode should be one of constant, reflect, replicate, circular, but got {}.".format(mode)
+
     data_format = data_format.upper()
-    input_dim = len(input.shape)
+    x_dim = len(x.shape)
     original_data_format = data_format
     unsqueezed_dim = []
 
@@ -569,24 +569,24 @@ def pad(x,
     if isinstance(pad, Variable):
         if data_format in ["NCL", "NCHW", "NCDHW"]:
             data_format = "NCDHW"
-            if input_dim == 3:
+            if x_dim == 3:
                 pad = concat([zeros((4, ), dtype="int32"), pad], axis=0)
                 unsqueezed_dim = [3, 4]
-                input = unsqueeze(input, axes=unsqueezed_dim)
-            elif input_dim == 4:
+                x = unsqueeze(x, axes=unsqueezed_dim)
+            elif x_dim == 4:
                 pad = concat([pad, zeros((2, ), dtype="int32")], axis=0)
                 unsqueezed_dim = [2]
-                input = unsqueeze(input, axes=unsqueezed_dim)
+                x = unsqueeze(x, axes=unsqueezed_dim)
         elif data_format in ["NLC", "NHWC", "NDHWC"]:
             data_format = "NDHWC"
-            if input_dim == 3:
+            if x_dim == 3:
                 pad = concat([zeros((4, ), dtype="int32"), pad], axis=0)
                 unsqueezed_dim = [2, 3]
-                input = unsqueeze(input, axes=unsqueezed_dim)
-            elif input_dim == 4:
+                x = unsqueeze(x, axes=unsqueezed_dim)
+            elif x_dim == 4:
                 pad = concat([pad, zeros((2, ), dtype="int32")], axis=0)
                 unsqueezed_dim = [1]
-                input = unsqueeze(input, axes=unsqueezed_dim)
+                x = unsqueeze(x, axes=unsqueezed_dim)
         else:
            raise ValueError("data_format should be in one of "
                             "[NCL, NCHW, NCDHW, NLC, NHWC, NDHWC] but got {}".format(
                                 data_format))
@@ -594,48 +594,107 @@ def pad(x,
     else:
         if data_format in ["NCL", "NCHW", "NCDHW"]:
             data_format = "NCDHW"
-            if input_dim == 3:
+            if x_dim == 3:
                 pad = [0, 0, 0, 0] + pad
                 unsqueezed_dim = [3, 4]
-                input = unsqueeze(input, axes=unsqueezed_dim)
-            elif input_dim == 4:
+                x = unsqueeze(x, axes=unsqueezed_dim)
+            elif x_dim == 4:
                 pad = pad + [0, 0]
                 unsqueezed_dim = [2]
-                input = unsqueeze(input, axes=unsqueezed_dim)
+                x = unsqueeze(x, axes=unsqueezed_dim)
         elif data_format in ["NLC", "NHWC", "NDHWC"]:
             data_format = "NDHWC"
-            if input_dim == 3:
+            if x_dim == 3:
                 pad = [0, 0, 0, 0] + pad
                 unsqueezed_dim = [2, 3]
-                input = unsqueeze(input, axes=unsqueezed_dim)
-            elif input_dim == 4:
+                x = unsqueeze(x, axes=unsqueezed_dim)
+            elif x_dim == 4:
                 pad = pad + [0, 0]
                 unsqueezed_dim = [1]
-                input = unsqueeze(input, axes=unsqueezed_dim)
+                x = unsqueeze(x, axes=unsqueezed_dim)
         else:
            raise ValueError("data_format should be in one of "
                             "[NCL, NCHW, NCDHW, NLC, NHWC, NDHWC] but got {}".format(
                                 data_format))
 
-    attrs = {'mode': mode, 'value': value, 'data_format': data_format}
-
-    inputs = {'X': [input]}
-    if isinstance(pad, Variable):
-        inputs['Paddings'] = [pad]
-        attrs['paddings'] = []
+    if in_dygraph_mode():
+        if isinstance(pad, Variable):
+            out = core.ops.pad3d(x, pad, "mode", mode, "value", value,
+                                 "data_format", data_format, "name", name)
+        else:
+            out = core.ops.pad3d(x, "paddings", pad, "mode", mode, "value",
+                                 value, "data_format", data_format, "name",
+                                 name)
     else:
-        attrs['paddings'] = pad
+        attrs = {'mode': mode, 'value': value, 'data_format': data_format}
+        inputs = {'X': [x]}
+        if isinstance(pad, Variable):
+            inputs['Paddings'] = [pad]
+            attrs['paddings'] = []
+        else:
+            attrs['paddings'] = pad
 
-    helper = LayerHelper('pad3d', **locals())
+        helper = LayerHelper('pad3d', **locals())
 
-    assert mode in ['reflect', 'replicate', 'constant', 'circular'], \
-        "mode should be one of constant, reflect, replicate, circular, but got {}.".format(mode)
+        dtype = helper.input_dtype(input_param_name='x')
+        out = helper.create_variable_for_type_inference(dtype)
+        helper.append_op(
+            type='pad3d', inputs=inputs, outputs={"Out": out}, attrs=attrs)
 
-    dtype = helper.input_dtype(input_param_name='input')
-    out = helper.create_variable_for_type_inference(dtype)
-    helper.append_op(
-        type='pad3d', inputs=inputs, outputs={"Out": out}, attrs=attrs)
     if len(unsqueezed_dim) != 0:
         out = squeeze(out, axes=unsqueezed_dim)
 
     return out
+
+
+def cosine_similarity(x1, x2, dim=1, eps=1e-8):
+    """
+    Compute cosine similarity between x1 and x2 along dim.
+
+    Parameters:
+        x1 (Variable): First input. float32/float64.
+        x2 (Variable): Second input. float32/float64.
+        dim (int): Dimension of vectors to compute cosine similarity. Default is 1.
+        eps (float): Small value to avoid division by zero. Default is 1e-8.
+
+    Returns: a Tensor representing cosine similarity between x1 and x2 along dim.
+    Return Type: Variable
+
+    Examples:
+        .. code-block:: text
+
+            Case 0:
+                x1 = [[0.8024077  0.9927354  0.27238318 0.8344984 ]
+                      [0.48949873 0.5797396  0.65444374 0.66510963]
+                      [0.1031398  0.9614342  0.08365563 0.6796464 ]
+                      [0.10760343 0.7461209  0.7726148  0.5801006 ]]
+                x2 = [[0.62913156 0.1536727  0.9847992  0.04591406]
+                      [0.9098952  0.15715368 0.8671125  0.3156102 ]
+                      [0.4427798  0.54136837 0.5276275  0.32394758]
+                      [0.3769419  0.8535014  0.48041078 0.9256797 ]]
+                dim = 1
+                eps = 1e-8
+                Out: [0.5275037  0.8368967  0.75037485 0.9245899]
+
+    Code Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn as nn
+            import numpy as np
+            paddle.disable_static()
+
+            np.random.seed(0)
+            x1 = np.random.rand(2, 3)
+            x2 = np.random.rand(2, 3)
+            x1 = paddle.to_variable(x1)
+            x2 = paddle.to_variable(x2)
+            result = paddle.nn.functional.cosine_similarity(x1, x2, dim=0)
+            print(result.numpy())
+            # [0.99806249 0.9817672  0.94987036]
+
+    """
+    w12 = sum(elementwise_mul(x1, x2), dim=dim)
+    w1 = sum(elementwise_mul(x1, x1), dim=dim)
+    w2 = sum(elementwise_mul(x2, x2), dim=dim)
+    n12 = sqrt(clamp(w1 * w2, min=eps * eps))
+    cos_sim = w12 / n12
+    return cos_sim
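A quick NumPy cross-check of the formula implemented above: the clamp guards the denominator with eps squared before the square root, so fully zero vectors cannot divide by zero. Illustrative sketch, reproducing the docstring's dim=0 example:

import numpy as np

np.random.seed(0)
x1 = np.random.rand(2, 3)
x2 = np.random.rand(2, 3)
eps = 1e-8

w12 = (x1 * x2).sum(axis=0)  # sum(elementwise_mul(x1, x2), dim=0)
w1 = (x1 * x1).sum(axis=0)
w2 = (x2 * x2).sum(axis=0)
cos = w12 / np.sqrt(np.maximum(w1 * w2, eps * eps))
print(cos)  # [0.99806249 0.9817672  0.94987036], as in the docstring above
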
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index b6933a1545cc9..19ff068ef1101 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -41,6 +41,13 @@
 from .common import ReflectionPad1d  #DEFINE_ALIAS
 from .common import ReplicationPad1d  #DEFINE_ALIAS
 from .common import ConstantPad1d  #DEFINE_ALIAS
+from .common import ReflectionPad2d  #DEFINE_ALIAS
+from .common import ReplicationPad2d  #DEFINE_ALIAS
+from .common import ConstantPad2d  #DEFINE_ALIAS
+from .common import ZeroPad2d  #DEFINE_ALIAS
+from .common import ReplicationPad3d  #DEFINE_ALIAS
+from .common import ConstantPad3d  #DEFINE_ALIAS
+from .common import CosineSimilarity
 from .common import Embedding  #DEFINE_ALIAS
 from .common import Linear  #DEFINE_ALIAS
 from .common import Flatten  #DEFINE_ALIAS
diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index 7c40ea45166a6..20a48ef181400 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -22,8 +22,23 @@
 from .. import functional as F
 
 __all__ = [
-    'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample', 'Pad',
-    'Pad2D', 'ReflectionPad1d', 'ReplicationPad1d', 'ConstantPad1d'
+    'BilinearTensorProduct',
+    'Pool2D',
+    'Embedding',
+    'Linear',
+    'UpSample',
+    'Pad',
+    'Pad2D',
+    'ReflectionPad1d',
+    'ReplicationPad1d',
+    'ConstantPad1d',
+    'ReflectionPad2d',
+    'ReplicationPad2d',
+    'ConstantPad2d',
+    'ZeroPad2d',
+    'ConstantPad3d',
+    'ReplicationPad3d',
+    'CosineSimilarity',
 ]
 
@@ -369,8 +384,8 @@ class Pad(layers.Layer):
     Examples:
         .. code-block:: text
 
-            Input = [[[[[1., 2., 3.],
-                        [4., 5., 6.]]]]]
+            x = [[[[[1., 2., 3.],
+                    [4., 5., 6.]]]]]
 
             Case 0:
                 pad = [2, 2, 1, 1, 0, 0],
@@ -407,20 +422,23 @@ class Pad(layers.Layer):
     Code Examples:
         .. code-block:: python
-
+            import paddle
             import paddle.fluid as fluid
             import paddle.nn as nn
             import numpy as np
-            data = np.ones((1, 1, 2, 2)).astype('float32')
-            my_pad = nn.Pad(pad=[1, 1, 1, 1])
-            with fluid.dygraph.guard():
-                data = fluid.dygraph.to_variable(data)
-                result = my_pad(data)
-                print(result.numpy())
-                # [[[[0. 0. 0. 0.]
-                #   [0. 1. 1. 0.]
-                #   [0. 1. 1. 0.]
-                #   [0. 0. 0. 0.]]]]
+            paddle.disable_static()
+
+            x_shape = (1, 1, 3, 4)
+            x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + 1
+            tensor_x = paddle.to_variable(x)
+            my_pad = nn.Pad(pad=[1, 1, 1, 1])
+            y = my_pad(tensor_x)
+            print(y.numpy())
+            # [[[[ 0.  0.  0.  0.  0.  0.]
+            #    [ 0.  1.  2.  3.  4.  0.]
+            #    [ 0.  5.  6.  7.  8.  0.]
+            #    [ 0.  9. 10. 11. 12.  0.]
+            #    [ 0.  0.  0.  0.  0.  0.]]]]
     """
 
     def __init__(self,
@@ -436,8 +454,8 @@ def __init__(self,
         self._pad = pad
         self._name = name
 
-    def forward(self, input):
-        return F.pad(input,
+    def forward(self, x):
+        return F.pad(x,
                      pad=self._pad,
                      mode=self._mode,
                      value=self._value,
@@ -447,15 +465,12 @@ def forward(self, input):
 
 class ReflectionPad1d(layers.Layer):
     """
-    :alias_main: paddle.nn.ReflectionPad1d
-    :alias: paddle.nn.ReflectionPad1d
-
     This interface is used to construct a callable object of the ``ReflectionPad1d`` class.
     Uses reflection of the input boundaries to pad the input tensor.
 
     Parameters:
         pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
-            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0].
+            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0].
        data_format (str): A string from: "NCL", "NLC". Specify the data format of the input data.
           Default is "NCL"
        name (str, optional) : The default value is None.  Normally there is no need for
            user to set this property.  For more information, please refer to :ref:`api_guide_Name`.
 
    Returns:
        None
 
    Examples:
        .. code-block:: text
 
-            Input = [[[1., 2., 3.],
-                      [4., 5., 6.]]]
+            x = [[[1., 2., 3.],
+                  [4., 5., 6.]]]
            pad = [1, 2],
            Out = [[[2. 1. 2. 3. 2. 1.]
                    [5. 4. 5. 6. 5. 4.]]]
 
    Code Examples:
        .. code-block:: python
 
-            import paddle.fluid as fluid
+            import paddle
            import paddle.nn as nn
            import numpy as np
+            paddle.disable_static()
+
            input_shape = (1, 2, 3)
            pad = [1, 2]
            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
            my_pad = nn.ReflectionPad1d(padding=pad)
-            with fluid.dygraph.guard():
-                data = fluid.dygraph.to_variable(data)
-                result = my_pad(data)
+            data = paddle.to_variable(data)
+            result = my_pad(data)
            print(result.numpy())
            # [[[2. 1. 2. 3. 2. 1.]
            #   [5. 4. 5. 6. 5. 4.]]]
@@ -498,8 +514,8 @@ def __init__(self, padding=[0, 0], data_format="NCL", name=None):
         self._pad = padding
         self._name = name
 
-    def forward(self, input):
-        return F.pad(input,
+    def forward(self, x):
+        return F.pad(x,
                      pad=self._pad,
                      mode=self._mode,
                      data_format=self._data_format,
@@ -508,15 +524,12 @@ def forward(self, input):
 
 class ReplicationPad1d(layers.Layer):
     """
-    :alias_main: paddle.nn.ReplicationPad1d
-    :alias: paddle.nn.ReplicationPad1d
-
     This interface is used to construct a callable object of the ``ReplicationPad1d`` class.
     Uses input boundaries to pad the input tensor.
 
     Parameters:
         pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
-            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0].
+            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0].
        data_format (str): A string from: "NCL", "NLC". Specify the data format of the input data.
           Default is "NCL"
        name (str, optional) : The default value is None.  Normally there is no need for
            user to set this property.  For more information, please refer to :ref:`api_guide_Name`.
 
    Returns:
        None
 
    Examples:
        .. code-block:: text
 
-            Input = [[[1., 2., 3.],
-                      [4., 5., 6.]]]
+            x = [[[1., 2., 3.],
+                  [4., 5., 6.]]]
            pad = [1, 2],
            Out = [[[1. 1. 2. 3. 3. 3.]
                    [4. 4. 5. 6. 6. 6.]]]
 
    Code Examples:
        .. code-block:: python
 
-            import paddle.fluid as fluid
+            import paddle
            import paddle.nn as nn
            import numpy as np
+            paddle.disable_static()
+
            input_shape = (1, 2, 3)
            pad = [1, 2]
            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
            my_pad = nn.ReplicationPad1d(padding=pad)
-            with fluid.dygraph.guard():
-                data = fluid.dygraph.to_variable(data)
-                result = my_pad(data)
+            data = paddle.to_variable(data)
+            result = my_pad(data)
            print(result.numpy())
            # [[[1. 1. 2. 3. 3. 3.]
            #   [4. 4. 5. 6. 6. 6.]]]
    """
 
@@ -559,8 +573,8 @@ def __init__(self, padding=[0, 0], data_format="NCL", name=None):
         self._pad = padding
         self._name = name
 
-    def forward(self, input):
-        return F.pad(input,
+    def forward(self, x):
+        return F.pad(x,
                      pad=self._pad,
                      mode=self._mode,
                      data_format=self._data_format,
@@ -569,15 +583,13 @@ def forward(self, input):
 
 class ConstantPad1d(layers.Layer):
     """
-    :alias_main: paddle.nn.ConstantPad1d
-    :alias: paddle.nn.ConstantPad1d
-
     This interface is used to construct a callable object of the ``ConstantPad1d`` class.
     Uses a constant value to pad the input tensor.
 
     Parameters:
         pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
-            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0].
+            of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0].
+        value (float32): The value to fill the padded areas. Default is 0.0
        data_format (str): A string from: "NCL", "NLC". Specify the data format of the input data.
           Default is "NCL"
        name (str, optional) : The default value is None.  Normally there is no need for
            user to set this property.  For more information, please refer to :ref:`api_guide_Name`.
 
    Returns:
        None
 
    Examples:
        .. code-block:: text
 
-            Input = [[[1., 2., 3.],
-                      [4., 5., 6.]]]
+            x = [[[1., 2., 3.],
+                  [4., 5., 6.]]]
            pad = [1, 2],
            value = 0.0
            Out = [[[0. 1. 2. 3. 0. 0.]
                    [0. 4. 5. 6. 0. 0.]]]
 
    Code Examples:
        ..
 code-block:: python
-
-            import paddle.fluid as fluid
+
+            import paddle
            import paddle.nn as nn
            import numpy as np
+            paddle.disable_static()
+
            input_shape = (1, 2, 3)
            pad = [1, 2]
-            value = 0.0
            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
-            my_pad = nn.ConstantPad1d(padding=pad, value=value)
-            with fluid.dygraph.guard():
-                data = fluid.dygraph.to_variable(data)
-                result = my_pad(data)
+            my_pad = nn.ConstantPad1d(padding=pad)
+            data = paddle.to_variable(data)
+            result = my_pad(data)
            print(result.numpy())
            # [[[0. 1. 2. 3. 0. 0.]
            #   [0. 4. 5. 6. 0. 0.]]]
    """
 
@@ -623,10 +635,458 @@ def __init__(self, padding=[0, 0], value=0.0, data_format="NCL", name=None):
         self._value = value
         self._name = name
 
-    def forward(self, input):
-        return F.pad(input,
+    def forward(self, x):
+        return F.pad(x,
+                     pad=self._pad,
+                     mode=self._mode,
+                     value=self._value,
+                     data_format=self._data_format,
+                     name=self._name)
+
+
+class ConstantPad2d(layers.Layer):
+    """
+    This interface is used to construct a callable object of the ``ConstantPad2d`` class.
+    Uses a constant value to pad the input tensor.
+
+    Parameters:
+        pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
+            of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom).
+            Default is [0, 0, 0, 0].
+        value (float32): The value to fill the padded areas. Default is 0.0
+        data_format (str): A string from: "NCHW", "NHWC". Specify the data format of the input data.
+           Default is "NCHW"
+        name (str, optional) : The default value is None.  Normally there is no need for
+            user to set this property.  For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: text
+
+            x = [[[[1., 2., 3.],
+                   [4., 5., 6.]]]]
+            pad = [1, 1, 0, 0]
+            value = 0.0
+            Out = [[[[0. 1. 2. 3. 0.]
+                     [0. 4. 5. 6. 0.]]]]
+
+    Code Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn as nn
+            import numpy as np
+            paddle.disable_static()
+
+            input_shape = (1, 1, 2, 3)
+            pad = [1, 0, 1, 2]
+            data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1
+            my_pad = nn.ConstantPad2d(padding=pad)
+            data = paddle.to_variable(data)
+            result = my_pad(data)
+            print(result.numpy())
+            # [[[[0. 0. 0. 0.]
+            #    [0. 1. 2. 3.]
+            #    [0. 4. 5. 6.]
+            #    [0. 0. 0. 0.]
+            #    [0. 0. 0. 0.]]]]
+    """
+
+    def __init__(self,
+                 padding=[0, 0, 0, 0],
+                 value=0.0,
+                 data_format="NCHW",
+                 name=None):
+        super(ConstantPad2d, self).__init__()
+        self._mode = "constant"
+        self._data_format = data_format
+        self._pad = padding
+        self._value = value
+        self._name = name
+
+    def forward(self, x):
+        return F.pad(x,
+                     pad=self._pad,
+                     mode=self._mode,
+                     value=self._value,
+                     data_format=self._data_format,
+                     name=self._name)
+
+
+class ZeroPad2d(layers.Layer):
+    """
+    This interface is used to construct a callable object of the ``ZeroPad2d`` class.
+    Uses 0 to pad the input tensor.
+
+    Parameters:
+        pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions
+            of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom).
+            Default is [0, 0, 0, 0].
+        data_format (str): A string from: "NCHW", "NHWC". Specify the data format of the input data.
+           Default is "NCHW"
+        name (str, optional) : The default value is None.  Normally there is no need for
+            user to set this property.  For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        None
+
+    Examples:
+        ..
code-block:: text + + x = [[[[1., 2., 3.], + [4., 5., 6.]]]] + pad = [1, 1, 0, 0] + Out = [[[[0. 1. 2. 3. 0.] + [0. 4. 5. 6. 0.]]]] + + Code Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + input_shape = (1, 1, 2, 3) + pad = [1, 0, 1, 2] + data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 + my_pad = nn.ZeroPad2d(padding=pad) + data = paddle.to_variable(data) + result = my_pad(data) + print(result.numpy()) + # [[[[0. 0. 0. 0.] + # [0. 1. 2. 3.] + # [0. 4. 5. 6.] + # [0. 0. 0. 0.] + # [0. 0. 0. 0.]]]] + """ + + def __init__(self, padding=[0, 0, 0, 0], data_format="NCHW", name=None): + super(ZeroPad2d, self).__init__() + self._mode = "constant" + self._data_format = data_format + self._pad = padding + self._name = name + + def forward(self, x): + return F.pad(x, + pad=self._pad, + mode=self._mode, + data_format=self._data_format, + name=self._name) + + +class ReplicationPad2d(layers.Layer): + """ + This interface is used to construct a callable object of the ``ReplicationPad2d`` class. + Uses input boundaries to pad the input tensor. + + Parameters: + pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0]. + data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. + Default is "NCHW" + name (str, optional) : The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + None + + Examples: + .. code-block:: text + + x = [[[[1., 2., 3.], + [4., 5., 6.]]]] + pad = [1, 1, 0, 0] + Out = [[[[1. 1. 2. 3. 3.] + [4. 4. 5. 6. 6.]]]] + + Code Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + input_shape = (1, 1, 2, 3) + pad = [1, 0, 1, 2] + data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 + my_pad = nn.ReplicationPad2d(padding=pad) + data = paddle.to_variable(data) + result = my_pad(data) + print(result.numpy()) + # [[[[1. 1. 2. 3.] + # [1. 1. 2. 3.] + # [4. 4. 5. 6.] + # [4. 4. 5. 6.] + # [4. 4. 5. 6.]]]] + """ + + def __init__(self, padding=[0, 0, 0, 0], data_format="NCHW", name=None): + super(ReplicationPad2d, self).__init__() + self._mode = "replicate" + self._data_format = data_format + self._pad = padding + self._name = name + + def forward(self, x): + return F.pad(x, + pad=self._pad, + mode=self._mode, + data_format=self._data_format, + name=self._name) + + +class ReflectionPad2d(layers.Layer): + """ + This interface is used to construct a callable object of the ``ReflectionPad2d`` class. + Uses reflection of the input boundaries to pad the input tensor. + + Parameters: + pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0]. + data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. + Default is "NCHW" + name (str, optional) : The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + None + + Examples: + .. code-block:: text + + x = [[[[1., 2., 3.], + [4., 5., 6.]]]] + pad = [1, 1, 0, 0] + Out = [[[[2. 1. 2. 3. 2.] + [5. 4. 5. 6. 
5.]]]] + + Code Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + input_shape = (1, 1, 4, 3) + pad = [1, 0, 1, 2] + data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 + my_pad = nn.ReflectionPad2d(padding=pad) + data = paddle.to_variable(data) + result = my_pad(data) + print(result.numpy()) + # [[[[ 5. 4. 5. 6.] + # [ 2. 1. 2. 3.] + # [ 5. 4. 5. 6.] + # [ 8. 7. 8. 9.] + # [11. 10. 11. 12.] + # [ 8. 7. 8. 9.] + # [ 5. 4. 5. 6.]]]] + """ + + def __init__(self, padding=[0, 0, 0, 0], data_format="NCHW", name=None): + super(ReflectionPad2d, self).__init__() + self._mode = "reflect" + self._data_format = data_format + self._pad = padding + self._name = name + + def forward(self, x): + return F.pad(x, + pad=self._pad, + mode=self._mode, + data_format=self._data_format, + name=self._name) + + +class ConstantPad3d(layers.Layer): + """ + This interface is used to construct a callable object of the ``ConstantPad3d`` class. + Uses a constant value to pad the input tensor. + + Parameters: + pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0, 0, 0]. + value (float32): The value to fill the padded areas. Default is 0.0 + data_format (str): An string from: "NCDHW", "NDHWC". Specify the data format of the input data. + Default is "NCDHW" + name (str, optional) : The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + None + + Examples: + .. code-block:: text + + x = [[[[[1., 2., 3.], + [4., 5., 6.]]]]] + pad = [1, 2, 0, 0, 0, 0] + value = 0.0 + Out = [[[[[0. 1. 2. 3. 0. 0.] + [0. 4. 5. 6. 0. 0.]]]]] + + Code Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + input_shape = (1, 1, 1, 2, 3) + pad = [1, 0, 1, 2, 0, 0] + data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 + my_pad = nn.ConstantPad3d(padding=pad) + data = paddle.to_variable(data) + result = my_pad(data) + print(result.numpy()) + # [[[[[0. 0. 0. 0.] + # [0. 1. 2. 3.] + # [0. 4. 5. 6.] + # [0. 0. 0. 0.] + # [0. 0. 0. 0.]]]]] + """ + + def __init__(self, + padding=[0, 0, 0, 0, 0, 0], + value=0.0, + data_format="NCDHW", + name=None): + super(ConstantPad3d, self).__init__() + self._mode = "constant" + self._data_format = data_format + self._pad = padding + self._value = value + self._name = name + + def forward(self, x): + return F.pad(x, pad=self._pad, mode=self._mode, value=self._value, data_format=self._data_format, name=self._name) + + +class ReplicationPad3d(layers.Layer): + """ + This interface is used to construct a callable object of the ``ReplicationPad3d`` class. + Uses input boundaries to pad the input tensor. + + Parameters: + pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0, 0, 0]. + data_format (str): An string from: "NCDHW", "NDHWC". Specify the data format of the input data. + Default is "NCDHW" + name (str, optional) : The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + None + + Examples: + .. 
code-block:: text + + x = [[[[[1., 2., 3.], + [4., 5., 6.]]]]] + pad = [1, 2, 0, 0, 0, 0] + Out = [[[[[1. 1. 2. 3. 3. 3.] + [4. 4. 5. 6. 6. 6.]]]]] + + Code Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + input_shape = (1, 1, 1, 2, 3) + pad = [1, 0, 1, 2, 0, 0] + data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 + my_pad = nn.ReplicationPad3d(padding=pad) + data = paddle.to_variable(data) + result = my_pad(data) + print(result.numpy()) + # [[[[[1. 1. 2. 3.] + # [1. 1. 2. 3.] + # [4. 4. 5. 6.] + # [4. 4. 5. 6.] + # [4. 4. 5. 6.]]]]] + """ + + def __init__(self, + padding=[0, 0, 0, 0, 0, 0], + data_format="NCDHW", + name=None): + super(ReplicationPad3d, self).__init__() + self._mode = "replicate" + self._data_format = data_format + self._pad = padding + self._name = name + + def forward(self, x): + return F.pad(x, + pad=self._pad, + mode=self._mode, + data_format=self._data_format, + name=self._name) + + +class CosineSimilarity(layers.Layer): + """ + This interface is used to compute cosine similarity between x1 and x2 along dim. + + Parameters: + dim (int): Dimension of vectors to compute cosine similarity. Default is 1. + eps(float): Small value to avoid division by zero. Default is 1e-8. + Returns: + None + + Examples: + .. code-block:: text + + Case 0: + x1 = [[0.8024077 0.9927354 0.27238318 0.8344984 ] + [0.48949873 0.5797396 0.65444374 0.66510963] + [0.1031398 0.9614342 0.08365563 0.6796464 ] + [0.10760343 0.7461209 0.7726148 0.5801006 ]] + x2 = [[0.62913156 0.1536727 0.9847992 0.04591406] + [0.9098952 0.15715368 0.8671125 0.3156102 ] + [0.4427798 0.54136837 0.5276275 0.32394758] + [0.3769419 0.8535014 0.48041078 0.9256797 ]] + dim = 1 + eps = 1e-8 + Out: [0.5275037 0.8368967 0.75037485 0.9245899] + + Code Examples: + .. 
code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + np.random.seed(0) + x1 = np.random.rand(2,3) + x2 = np.random.rand(2,3) + x1 = paddle.to_variable(x1) + x2 = paddle.to_variable(x2) + + cos_sim_func = nn.CosineSimilarity(dim=0) + result = cos_sim_func(x1, x2) + print(result.numpy()) + # [0.99806249 0.9817672 0.94987036] + """ + + def __init__(self, dim=1, eps=1e-8): + super(CosineSimilarity, self).__init__() + self._dim = dim + self._eps = eps + + def forward(self, x1, x2): + return F.cosine_similarity(x1, x2, dim=self._dim, eps=self._eps) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 21cae803716a9..c2134280902c4 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -111,6 +111,7 @@ from .math import elementwise_div #DEFINE_ALIAS from .math import elementwise_floordiv #DEFINE_ALIAS from .math import elementwise_max #DEFINE_ALIAS +from .math import elementwise_mul #DEFINE_ALIAS from .math import elementwise_min #DEFINE_ALIAS from .math import elementwise_mod #DEFINE_ALIAS from .math import elementwise_pow #DEFINE_ALIAS From 85f975b2635ccc9ddf4eb498a9d74f94a02a004b Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 17 Aug 2020 05:37:10 +0000 Subject: [PATCH 09/15] test=develop, remove padding default value --- python/paddle/nn/__init__.py | 1 - python/paddle/nn/functional/common.py | 32 ++--- python/paddle/nn/layer/__init__.py | 3 +- python/paddle/nn/layer/common.py | 192 +++++--------------------- 4 files changed, 47 insertions(+), 181 deletions(-) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index e32c30573c0bb..8ce12788b7ee8 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -61,7 +61,6 @@ from .layer.common import BilinearTensorProduct #DEFINE_ALIAS from .layer.common import Pool2D #DEFINE_ALIAS from .layer.common import Pad2D #DEFINE_ALIAS -from .layer.common import Pad #DEFINE_ALIAS from .layer.common import ReflectionPad1d #DEFINE_ALIAS from .layer.common import ReplicationPad1d #DEFINE_ALIAS from .layer.common import ConstantPad1d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index db90f81d056a5..a0844fcbbfc19 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -25,12 +25,12 @@ from ...fluid.layers import pad2d #DEFINE_ALIAS from ...fluid.layers import unfold #DEFINE_ALIAS from ...fluid.layers import assign #DEFINE_ALIAS -from ...fluid.layers import squeeze -from ...fluid.layers import unsqueeze -from ...fluid.layers import elementwise_mul -from ...tensor import clamp -from ...tensor import sum -from ...tensor import sqrt +from ...fluid.layers import squeeze #DEFINE_ALIAS +from ...fluid.layers import unsqueeze #DEFINE_ALIAS +from ...fluid.layers import elementwise_mul #DEFINE_ALIAS +from ...tensor import clamp #DEFINE_ALIAS +from ...tensor import sum #DEFINE_ALIAS +from ...tensor import sqrt #DEFINE_ALIAS #from ...fluid.layers import fc #DEFINE_ALIAS from ...fluid.layers import pad_constant_like #DEFINE_ALIAS @@ -455,12 +455,7 @@ def _is_list_or_turple_(data): return out -def pad(x, - pad=[0, 0, 0, 0], - mode='constant', - value=0, - data_format="NCHW", - name=None): +def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): """ Pad tensor according to 'pad' and 'mode'. 
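For intuition, the four modes correspond one-to-one to numpy.pad modes, which is also the oracle the unit tests later in this series use; a minimal sketch in plain NumPy (names illustrative):

    import numpy as np

    x = np.array([[1., 2., 3.],
                  [4., 5., 6.]])
    pw = ((0, 0), (1, 1))  # pad one element on each side of the last axis
    np.pad(x, pw, mode="constant", constant_values=0)  # 'constant'
    np.pad(x, pw, mode="reflect")  # 'reflect': mirror, edge element not repeated
    np.pad(x, pw, mode="edge")     # 'replicate': repeat the edge element
    np.pad(x, pw, mode="wrap")     # 'circular': wrap values around
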
If mode is 'reflect', pad[0] and pad[1] must be no greater @@ -472,7 +467,7 @@ def pad(x, of input will be padded. 1. If input dimension is 3, then the pad has the form (pad_left, pad_right). 2. If the input dimension is 4, then the pad has the form (pad_left, pad_right, pad_top, pad_bottom). 3. If the input dimension is 5, then the pad has the form - (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). Default is [0, 0, 0, 0]. + (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'. When in 'constant' mode, this op uses a constant value to pad the input tensor. @@ -560,6 +555,9 @@ def pad(x, "mode should be one of constant, reflect, replicate, circular, but got {}.".format(mode) data_format = data_format.upper() + assert data_format in ["NCL", "NCHW", "NCDHW", "NLC", "NHWC", "NDHWC"], \ + "data_format should be in one of [NCL, NCHW, NCDHW, NLC, NHWC, NDHWC], " \ + "but got {}".format(data_format) x_dim = len(x.shape) @@ -587,10 +585,6 @@ def pad(x, pad = concat([pad, zeros((2, ), dtype="int32")], axis=0) unsqueezed_dim = [1] x = unsqueeze(x, axes=unsqueezed_dim) - else: - raise ValueError, "data_format should be in one of " - "[NCL, NCHW, NCDHW, NLC, NHWC, NDHWC] but got {}".format( - data_format) else: if data_format in ["NCL", "NCHW", "NCDHW"]: data_format = "NCDHW" @@ -612,10 +606,6 @@ def pad(x, pad = pad + [0, 0] unsqueezed_dim = [1] x = unsqueeze(x, axes=unsqueezed_dim) - else: - raise ValueError, "data_format should be in one of " - "[NCL, NCHW, NCDHW, NLC, NHWC, NDHWC] but got {}".format( - data_format) if in_dygraph_mode(): if isinstance(pad, Variable): diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 19ff068ef1101..2e568af8161f6 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -37,7 +37,6 @@ from .common import BilinearTensorProduct #DEFINE_ALIAS from .common import Pool2D #DEFINE_ALIAS from .common import Pad2D #DEFINE_ALIAS -from .common import Pad #DEFINE_ALIAS from .common import ReflectionPad1d #DEFINE_ALIAS from .common import ReplicationPad1d #DEFINE_ALIAS from .common import ConstantPad1d #DEFINE_ALIAS @@ -47,7 +46,7 @@ from .common import ZeroPad2d #DEFINE_ALIAS from .common import ReplicationPad3d #DEFINE_ALIAS from .common import ConstantPad3d #DEFINE_ALIAS -from .common import CosineSimilarity +from .common import CosineSimilarity #DEFINE_ALIAS from .common import Embedding #DEFINE_ALIAS from .common import Linear #DEFINE_ALIAS from .common import Flatten #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 20a48ef181400..286b85e3f3468 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -27,7 +27,6 @@ 'Embedding', 'Linear', 'UpSample', - 'Pad', 'Pad2D', 'ReflectionPad1d', 'ReplicationPad1d', @@ -349,128 +348,14 @@ def forward(self, input): data_format=self._data_format) -class Pad(layers.Layer): - """ - :alias_main: paddle.nn.Pad - :alias: paddle.nn.Pad - - This interface is used to construct a callable object of the ``Pad`` class. - The Pad layer pads the input tensor boundaries according to 'pad' and 'mode'. - If mode is 'reflect', pad[0] and pad[1] must be no greater - than width-1. The height and depth dimensions have the same condition. - - Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. 1. 
If input dimension is 3, then the pad has the form (pad_left, - pad_right). 2. If the input dimension is 4, then the pad has the form (pad_left, pad_right, - pad_top, pad_bottom). 3. If the input dimension is 5, then the pad has the form - (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). Default is [0, 0, 0, 0]. - mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'. - When in 'constant' mode, this op uses a constant value to pad the input tensor. - When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor. - When in 'replicate' mode, uses input boundaries to pad the input tensor. - When in 'circular' mode, uses circular input to pad the input tensor. - Default is 'constant' - value (float32): The value to fill the padded areas in 'constant' mode . Default is 0.0 - data_format (str): An string from: "NCL", "NLC", NHWC", "NCHW", "NCDHW", "NDHWC". Specify the data format of - the input data. - Default is "NCHW" - name (str, optional) : The default value is None. Normally there is no need for - user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - None - - Examples: - .. code-block:: text - - x = [[[[[1., 2., 3.], - [4., 5., 6.]]]]] - - Case 0: - pad = [2, 2, 1, 1, 0, 0], - mode = 'constant' - pad_value = 0 - Out = [[[[[0. 0. 0. 0. 0. 0. 0.] - [0. 0. 1. 2. 3. 0. 0.] - [0. 0. 4. 5. 6. 0. 0.] - [0. 0. 0. 0. 0. 0. 0.]]]]] - - Case 1: - pad = [2, 2, 1, 1, 0, 0], - mode = 'reflect' - Out = [[[[[6. 5. 4. 5. 6. 5. 4.] - [3. 2. 1. 2. 3. 2. 1.] - [6. 5. 4. 5. 6. 5. 4.] - [3. 2. 1. 2. 3. 2. 1.]]]]] - - Case 2: - pad = [2, 2, 1, 1, 0, 0], - mode = 'replicate' - Out = [[[[[1. 1. 1. 2. 3. 3. 3.] - [1. 1. 1. 2. 3. 3. 3.] - [4. 4. 4. 5. 6. 6. 6.] - [4. 4. 4. 5. 6. 6. 6.]]]]] - - Case 3: - pad = [2, 2, 1, 1, 0, 0], - mode = 'circular' - Out = [[[[[5. 6. 4. 5. 6. 4. 5.] - [2. 3. 1. 2. 3. 1. 2.] - [5. 6. 4. 5. 6. 4. 5.] - [2. 3. 1. 2. 3. 1. 2.]]]]] - - Code Examples: - .. code-block:: python - import paddle - import paddle.fluid as fluid - import paddle.nn as nn - import numpy as np - paddle.disable_static() - - x_shape = (1, 1, 3, 4) - x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + 1 - tensor_x = paddle.to_variable(x) - my_pad = nn.Pad2D(paddings=[1, 1, 1, 1]) - y = my_pad(tensor_x) - print(y.numpy()) - # [[[[ 0. 0. 0. 0. 0. 0.] - # [ 0. 1. 2. 3. 4. 0.] - # [ 0. 5. 6. 7. 8. 0.] - # [ 0. 9. 10. 11. 12. 0.] - # [ 0. 0. 0. 0. 0. 0.]]]] - """ - - def __init__(self, - pad=[0, 0, 0, 0], - mode='constant', - value=0.0, - data_format="NCHW", - name=None): - super(Pad, self).__init__() - self._mode = mode - self._value = value - self._data_format = data_format - self._pad = pad - self._name = name - - def forward(self, x): - return F.pad(x, - pad=self._pad, - mode=self._mode, - value=self._value, - data_format=self._data_format, - name=self._name) - - class ReflectionPad1d(layers.Layer): """ This interface is used to construct a callable object of the ``ReflectionPad1d`` class. Uses reflection of the input boundaries to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right). data_format (str): An string from: "NCL", "NLC". 
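A note on 'reflect' mode: each pad amount must stay within the documented bound of the padded dimension minus one, because reflection mirrors around the edge element without repeating it. A sketch, assuming the eager-mode API (paddle.to_tensor) that later patches in this series adopt:

    import numpy as np
    import paddle
    import paddle.nn.functional as F
    paddle.disable_static()

    x = paddle.to_tensor(np.arange(3, dtype=np.float32).reshape(1, 1, 3))
    # width is 3, so pad amounts up to 2 satisfy pad <= width - 1
    y = F.pad(x, pad=[2, 2], mode="reflect", data_format="NCL")
    # pad=[3, 3] would violate the documented 'reflect' constraint
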
Specify the data format of the input data. Default is "NCL" name (str, optional) : The default value is None. Normally there is no need for @@ -484,7 +369,7 @@ class ReflectionPad1d(layers.Layer): x = [[[1., 2., 3.], [4., 5., 6.]]] - pad = [1, 2], + padding = [1, 2], Out = [[[2. 1. 2. 3. 2. 1.] [5. 4. 5. 6. 5. 4.]]] @@ -507,7 +392,7 @@ class ReflectionPad1d(layers.Layer): # [5. 4. 5. 6. 5. 4.]]] """ - def __init__(self, padding=[0, 0], data_format="NCL", name=None): + def __init__(self, padding, data_format="NCL", name=None): super(ReflectionPad1d, self).__init__() self._mode = "reflect" self._data_format = data_format @@ -528,8 +413,8 @@ class ReplicationPad1d(layers.Layer): Uses input boundaries to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right). data_format (str): A string from: "NCL", "NLC". Specify the data format of the input data. Default is "NCL" name (str, optional) : The default value is None. Normally there is no need for @@ -543,7 +428,7 @@ class ReplicationPad1d(layers.Layer): x = [[[1., 2., 3.], [4., 5., 6.]]] - pad = [1, 2], + padding = [1, 2], Out = [[[1. 1. 2. 3. 3. 3.] [4. 4. 5. 6. 6. 6.]]] @@ -566,7 +451,7 @@ class ReplicationPad1d(layers.Layer): # [4. 4. 5. 6. 6. 6.]]] """ - def __init__(self, padding=[0, 0], data_format="NCL", name=None): + def __init__(self, padding, data_format="NCL", name=None): super(ReplicationPad1d, self).__init__() self._mode = "replicate" self._data_format = data_format @@ -587,8 +472,8 @@ class ConstantPad1d(layers.Layer): Uses a constant value to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right). value (float32): The value to fill the padded areas. Default is 0.0 data_format (str): A string from: "NCL", "NLC". Specify the data format of the input data. Default is "NCL" name (str, optional) : The default value is None. Normally there is no need for @@ -603,7 +488,7 @@ class ConstantPad1d(layers.Layer): x = [[[1., 2., 3.], [4., 5., 6.]]] - pad = [1, 2], + padding = [1, 2], value = 0.0 Out = [[[0. 1. 2. 3. 0. 0.] [0. 4. 5. 6. 0. 0.]]] @@ -627,7 +512,7 @@ class ConstantPad1d(layers.Layer): # [0. 4. 5. 6. 0. 0.]]] """ - def __init__(self, padding=[0, 0], value=0.0, data_format="NCL", name=None): + def __init__(self, padding, value=0.0, data_format="NCL", name=None): super(ConstantPad1d, self).__init__() self._mode = "constant" self._data_format = data_format @@ -650,8 +535,8 @@ class ConstantPad2d(layers.Layer): Uses a constant value to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom). value (float32): The value to fill the padded areas. 
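Each of these Pad*d layers is a thin wrapper that stores mode, padding, and value and forwards to F.pad, so the layer call and the functional call agree; a quick equivalence check (illustrative, again assuming the eager API used later in the series):

    import numpy as np
    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F
    paddle.disable_static()

    x = paddle.to_tensor(np.random.rand(1, 2, 3).astype("float32"))
    pad, value = [1, 2], 9.0
    a = nn.ConstantPad1d(padding=pad, value=value)(x)
    b = F.pad(x, pad=pad, mode="constant", value=value, data_format="NCL")
    assert np.allclose(a.numpy(), b.numpy())
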
Default is 0.0 data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. Default is "NCHW" @@ -666,7 +551,7 @@ class ConstantPad2d(layers.Layer): x = [[[[1., 2., 3.], [4., 5., 6.]]]] - pad = [1, 1, 0, 0] + padding = [1, 1, 0, 0] value = 0.0 Out = [[[[0. 1. 2. 3. 0.] [0. 4. 5. 6. 0.]]]] @@ -720,8 +605,8 @@ class ZeroPad2d(layers.Layer): Uses 0 to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom). data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. Default is "NCHW" name (str, optional) : The default value is None. Normally there is no need for @@ -735,7 +620,7 @@ class ZeroPad2d(layers.Layer): x = [[[[1., 2., 3.], [4., 5., 6.]]]] - pad = [1, 1, 0, 0] + padding = [1, 1, 0, 0] Out = [[[[0. 1. 2. 3. 0.] [0. 4. 5. 6. 0.]]]] @@ -761,7 +646,7 @@ class ZeroPad2d(layers.Layer): # [0. 0. 0. 0.]]]] """ - def __init__(self, padding=[0, 0, 0, 0], data_format="NCHW", name=None): + def __init__(self, padding, data_format="NCHW", name=None): super(ZeroPad2d, self).__init__() self._mode = "constant" self._data_format = data_format @@ -782,8 +667,8 @@ class ReplicationPad2d(layers.Layer): Uses input boundaries to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom). data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. Default is "NCHW" name (str, optional) : The default value is None. Normally there is no need for @@ -797,7 +682,7 @@ class ReplicationPad2d(layers.Layer): x = [[[[1., 2., 3.], [4., 5., 6.]]]] - pad = [1, 1, 0, 0] + padding = [1, 1, 0, 0] Out = [[[[1. 1. 2. 3. 3.] [4. 4. 5. 6. 6.]]]] @@ -823,7 +708,7 @@ class ReplicationPad2d(layers.Layer): # [4. 4. 5. 6.]]]] """ - def __init__(self, padding=[0, 0, 0, 0], data_format="NCHW", name=None): + def __init__(self, padding, data_format="NCHW", name=None): super(ReplicationPad2d, self).__init__() self._mode = "replicate" self._data_format = data_format @@ -844,8 +729,8 @@ class ReflectionPad2d(layers.Layer): Uses reflection of the input boundaries to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom). data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. Default is "NCHW" name (str, optional) : The default value is None. Normally there is no need for @@ -859,7 +744,7 @@ class ReflectionPad2d(layers.Layer): x = [[[[1., 2., 3.], [4., 5., 6.]]]] - pad = [1, 1, 0, 0] + padding = [1, 1, 0, 0] Out = [[[[2. 1. 
2. 3. 2.] [5. 4. 5. 6. 5.]]]] @@ -887,7 +772,7 @@ class ReflectionPad2d(layers.Layer): # [ 5. 4. 5. 6.]]]] """ - def __init__(self, padding=[0, 0, 0, 0], data_format="NCHW", name=None): + def __init__(self, padding, data_format="NCHW", name=None): super(ReflectionPad2d, self).__init__() self._mode = "reflect" self._data_format = data_format @@ -908,8 +793,8 @@ class ConstantPad3d(layers.Layer): Uses a constant value to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0, 0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). value (float32): The value to fill the padded areas. Default is 0.0 data_format (str): An string from: "NCDHW", "NDHWC". Specify the data format of the input data. Default is "NCDHW" @@ -924,7 +809,7 @@ class ConstantPad3d(layers.Layer): x = [[[[[1., 2., 3.], [4., 5., 6.]]]]] - pad = [1, 2, 0, 0, 0, 0] + padding = [1, 2, 0, 0, 0, 0] value = 0.0 Out = [[[[[0. 1. 2. 3. 0. 0.] [0. 4. 5. 6. 0. 0.]]]]] @@ -951,11 +836,7 @@ class ConstantPad3d(layers.Layer): # [0. 0. 0. 0.]]]]] """ - def __init__(self, - padding=[0, 0, 0, 0, 0, 0], - value=0.0, - data_format="NCDHW", - name=None): + def __init__(self, padding, value=0.0, data_format="NCDHW", name=None): super(ConstantPad3d, self).__init__() self._mode = "constant" self._data_format = data_format @@ -978,8 +859,8 @@ class ReplicationPad3d(layers.Layer): Uses input boundaries to pad the input tensor. Parameters: - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions - of input will be padded. The pad has the form (pad_left, pad_right). Default is [0, 0, 0, 0, 0, 0]. + padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). data_format (str): An string from: "NCDHW", "NDHWC". Specify the data format of the input data. Default is "NCDHW" name (str, optional) : The default value is None. Normally there is no need for @@ -993,7 +874,7 @@ class ReplicationPad3d(layers.Layer): x = [[[[[1., 2., 3.], [4., 5., 6.]]]]] - pad = [1, 2, 0, 0, 0, 0] + padding = [1, 2, 0, 0, 0, 0] Out = [[[[[1. 1. 2. 3. 3. 3.] [4. 4. 5. 6. 6. 6.]]]]] @@ -1019,10 +900,7 @@ class ReplicationPad3d(layers.Layer): # [4. 4. 5. 
6.]]]]] """ - def __init__(self, - padding=[0, 0, 0, 0, 0, 0], - data_format="NCDHW", - name=None): + def __init__(self, padding, data_format="NCDHW", name=None): super(ReplicationPad3d, self).__init__() self._mode = "replicate" self._data_format = data_format From b3563822da36d0aa36309dbc4264440d0a80d898 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 17 Aug 2020 08:35:05 +0000 Subject: [PATCH 10/15] test=develop, rename var to tensor --- .../unittests/test_cosine_similarity_api.py | 121 ++++++++++ .../fluid/tests/unittests/test_pad3d_op.py | 219 +++++++++++++----- python/paddle/nn/functional/common.py | 15 +- python/paddle/nn/layer/common.py | 20 +- 4 files changed, 292 insertions(+), 83 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_cosine_similarity_api.py diff --git a/python/paddle/fluid/tests/unittests/test_cosine_similarity_api.py b/python/paddle/fluid/tests/unittests/test_cosine_similarity_api.py new file mode 100644 index 0000000000000..a0c26b512955e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cosine_similarity_api.py @@ -0,0 +1,121 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.fluid.core as core + +from paddle.fluid import Program, program_guard, Executor, default_main_program + + +class TestCosineSimilarityAPI(unittest.TestCase): + def setUp(self): + self.places = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + + def _get_numpy_out(self, x1, x2, dim=1, eps=1e-8): + w12 = np.sum(x1 * x2, axis=dim) + w1 = np.sum(x1 * x1, axis=dim) + w2 = np.sum(x2 * x2, axis=dim) + n12 = np.sqrt(np.clip(w1 * w2, eps * eps, None)) + cos_sim = w12 / n12 + return cos_sim + + def check_static_result(self, place): + paddle.enable_static() + + with program_guard(Program(), Program()): + shape = [10, 15] + dim = 1 + eps = 1e-8 + np.random.seed(0) + np_x1 = np.random.rand(*shape).astype(np.float32) + np_x2 = np.random.rand(*shape).astype(np.float32) + + x1 = paddle.data(name="x1", shape=shape) + x2 = paddle.data(name="x2", shape=shape) + result = F.cosine_similarity(x1, x2, dim=dim, eps=eps) + exe = Executor(place) + fetches = exe.run(default_main_program(), + feed={"x1": np_x1, + "x2": np_x2}, + fetch_list=[result]) + + np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps) + self.assertTrue(np.allclose(fetches[0], np_out)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph_1(self): + paddle.disable_static() + + shape = [10, 15] + dim = 1 + eps = 1e-8 + np.random.seed(1) + np_x1 = np.random.rand(*shape).astype(np.float32) + np_x2 = np.random.rand(*shape).astype(np.float32) + np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps) + + tensor_x1 = paddle.to_variable(np_x1) + tensor_x2 = paddle.to_variable(np_x2) + y = F.cosine_similarity(tensor_x1, tensor_x2, dim=dim, eps=eps) + + self.assertTrue(np.allclose(y.numpy(), np_out)) + + def test_dygraph_2(self): + paddle.disable_static() + + shape = [12, 13] + dim = 0 + eps = 1e-6 + np.random.seed(1) + np_x1 = np.random.rand(*shape).astype(np.float32) + np_x2 = np.random.rand(*shape).astype(np.float32) + np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps) + + tensor_x1 = paddle.to_variable(np_x1) + tensor_x2 = paddle.to_variable(np_x2) + y = F.cosine_similarity(tensor_x1, tensor_x2, dim=dim, eps=eps) + + self.assertTrue(np.allclose(y.numpy(), np_out)) + + def test_dygraph_3(self): + paddle.disable_static() + + shape1 = [10, 12, 10] + shape2 = [10, 1, 10] + dim = 2 + eps = 1e-6 + np.random.seed(1) + np_x1 = np.random.rand(*shape1).astype(np.float32) + np_x2 = np.random.rand(*shape2).astype(np.float32) + np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps) + + tensor_x1 = paddle.to_variable(np_x1) + tensor_x2 = paddle.to_variable(np_x2) + y = F.cosine_similarity(tensor_x1, tensor_x2, dim=dim, eps=eps) + + self.assertTrue(np.allclose(y.numpy(), np_out)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py index 8b7eeaf53ffff..fcf0ea25faf17 100644 --- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py @@ -18,7 +18,6 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -import paddle.fluid.dygraph as dg import paddle.fluid.core as core from paddle.fluid import Program, program_guard, Executor, default_main_program @@ -26,6 +25,7 @@ class TestPad3dOp(OpTest): def setUp(self): + paddle.enable_static() self.value = 0.0 self.variable_paddings = False self.initTestCase() @@ -160,7 +160,30 @@ def initTestCase(self): self.variable_paddings = True -class TestPad3dDygraph(unittest.TestCase): +class TestPadAPI(unittest.TestCase): + def setUp(self): + self.places = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + paddle.enable_static() + with program_guard(Program(), Program()): + input_shape = (1, 2, 3, 4, 5) + pad = [1, 2, 1, 1, 3, 4] + mode = "constant" + value = 100 + input_data = np.random.rand(*input_shape).astype(np.float32) + x = paddle.data(name="x", shape=input_shape) + result = F.pad(x=x, pad=pad, value=value, mode='constant') + exe = Executor(place) + fetches = exe.run(default_main_program(), + feed={"x": input_data}, + fetch_list=[result]) + + np_out = self._get_numpy_out(input_data, pad, mode, value) + self.assertTrue(np.allclose(fetches[0], np_out)) + def _get_numpy_out(self, input_data, pad, mode, value, data_format="NCDHW"): if data_format == "NCDHW": pad = [ @@ -190,36 +213,42 @@ def _get_numpy_out(self, input_data, pad, mode, value, data_format="NCDHW"): return out + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + def test_dygraph(self): + paddle.disable_static() + input_shape = (1, 2, 3, 4, 5) pad = [1, 2, 1, 1, 3, 4] mode = "constant" value = 100 input_data = np.random.rand(*input_shape).astype(np.float32) np_out = self._get_numpy_out(input_data, pad, mode, value) - place = paddle.CPUPlace() - with dg.guard(place) as g: - input = dg.to_variable(input_data) - output = F.pad(x=input, pad=pad, mode=mode, value=value) - self.assertTrue(np.allclose(output.numpy(), np_out)) + 
tensor_data = paddle.to_variable(input_data) + y = F.pad(tensor_data, pad=pad, mode=mode, value=value) + self.assertTrue(np.allclose(y.numpy(), np_out)) -class TestPadAPI(unittest.TestCase): - def _get_numpy_out(self, input_data, pad, mode, value, data_format="NCDHW"): - if data_format == "NCDHW": + +class TestPad1dAPI(unittest.TestCase): + def _get_numpy_out(self, + input_data, + pad, + mode, + value=0.0, + data_format="NCL"): + if data_format == "NCL": pad = [ (0, 0), (0, 0), - (pad[4], pad[5]), - (pad[2], pad[3]), (pad[0], pad[1]), ] else: pad = [ (0, 0), - (pad[4], pad[5]), - (pad[2], pad[3]), (pad[0], pad[1]), (0, 0), ] @@ -230,8 +259,6 @@ def _get_numpy_out(self, input_data, pad, mode, value, data_format="NCDHW"): out = np.pad(input_data, pad, mode=mode) elif mode == "replicate": out = np.pad(input_data, pad, mode="edge") - elif mode == "circular": - out = np.pad(input_data, pad, mode="wrap") return out @@ -240,58 +267,128 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(paddle.CUDAPlace(0)) - def check_static_result(self, place): - with program_guard(Program(), Program()): - input_shape = (1, 2, 3, 4, 5) - pad = [1, 2, 1, 1, 3, 4] - mode = "constant" + def test_class(self): + paddle.disable_static() + for place in self.places: + input_shape = (3, 4, 5) + pad = [1, 2] value = 100 input_data = np.random.rand(*input_shape).astype(np.float32) - x = paddle.data(name="x", shape=input_shape) - result = F.pad(x=x, pad=pad, value=value, mode='constant') - exe = Executor(place) - fetches = exe.run(default_main_program(), - feed={"x": input_data}, - fetch_list=[result]) - np_out = self._get_numpy_out(input_data, pad, mode, value) + pad_reflection = nn.ReflectionPad1d(padding=pad) + pad_replication = nn.ReplicationPad1d(padding=pad) + pad_constant = nn.ConstantPad1d(padding=pad, value=value) - self.assertTrue(np.allclose(fetches[0], np_out)) + data = paddle.to_variable(input_data) - def test_static(self): - for place in self.places: - self.check_static_result(place=place) + output = pad_reflection(data) + np_out = self._get_numpy_out( + input_data, pad, "reflect", data_format="NCL") + self.assertTrue(np.allclose(output.numpy(), np_out)) - def test_dygraph(self): + output = pad_replication(data) + np_out = self._get_numpy_out( + input_data, pad, "replicate", data_format="NCL") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + output = pad_constant(data) + np_out = self._get_numpy_out( + input_data, pad, "constant", value=value, data_format="NCL") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + +class TestPad2dAPI(unittest.TestCase): + def _get_numpy_out(self, + input_data, + pad, + mode, + value=0.0, + data_format="NCHW"): + if data_format == "NCHW": + pad = [ + (0, 0), + (0, 0), + (pad[2], pad[3]), + (pad[0], pad[1]), + ] + else: + pad = [ + (0, 0), + (pad[2], pad[3]), + (pad[0], pad[1]), + (0, 0), + ] + + if mode == "constant": + out = np.pad(input_data, pad, mode=mode, constant_values=value) + elif mode == "reflect": + out = np.pad(input_data, pad, mode=mode) + elif mode == "replicate": + out = np.pad(input_data, pad, mode="edge") + + return out + + def setUp(self): + self.places = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + + def test_class(self): + paddle.disable_static() for place in self.places: - input_shape = (1, 2, 3, 4, 5) - pad = [1, 2, 1, 1, 3, 4] - mode = "constant" + input_shape = (3, 4, 5, 6) + pad = [1, 2, 2, 1] value = 100 input_data = 
np.random.rand(*input_shape).astype(np.float32) - np_out = self._get_numpy_out(input_data, pad, mode, value) - with dg.guard(place) as g: - input = dg.to_variable(input_data) - output = F.pad(x=input, pad=pad, mode=mode, value=value) - self.assertTrue(np.allclose(output.numpy(), np_out)) + + pad_reflection = nn.ReflectionPad2d(padding=pad) + pad_replication = nn.ReplicationPad2d(padding=pad) + pad_constant = nn.ConstantPad2d(padding=pad, value=value) + pad_zero = nn.ZeroPad2d(padding=pad) + + data = paddle.to_variable(input_data) + + output = pad_reflection(data) + np_out = self._get_numpy_out( + input_data, pad, "reflect", data_format="NCHW") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + output = pad_replication(data) + np_out = self._get_numpy_out( + input_data, pad, "replicate", data_format="NCHW") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + output = pad_constant(data) + np_out = self._get_numpy_out( + input_data, pad, "constant", value=value, data_format="NCHW") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + output = pad_zero(data) + np_out = self._get_numpy_out( + input_data, pad, "constant", value=0, data_format="NCHW") + self.assertTrue(np.allclose(output.numpy(), np_out)) -class TestPad1dClass(unittest.TestCase): +class TestPad3dAPI(unittest.TestCase): def _get_numpy_out(self, input_data, pad, mode, value=0.0, - data_format="NCL"): - if data_format == "NCL": + data_format="NCDHW"): + if data_format == "NCDHW": pad = [ (0, 0), (0, 0), + (pad[4], pad[5]), + (pad[2], pad[3]), (pad[0], pad[1]), ] else: pad = [ (0, 0), + (pad[4], pad[5]), + (pad[2], pad[3]), (pad[0], pad[1]), (0, 0), ] @@ -311,33 +408,27 @@ def setUp(self): self.places.append(paddle.CUDAPlace(0)) def test_class(self): + paddle.disable_static() for place in self.places: - input_shape = (3, 4, 5) - pad = [1, 2] + input_shape = (3, 4, 5, 6, 7) + pad = [1, 2, 2, 1, 1, 0] value = 100 input_data = np.random.rand(*input_shape).astype(np.float32) - pad_reflection = nn.ReflectionPad1d(padding=pad) - pad_replication = nn.ReplicationPad1d(padding=pad) - pad_constant = nn.ConstantPad1d(padding=pad, value=value) - - with dg.guard(place) as g: - data = paddle.fluid.dygraph.to_variable(input_data) + pad_replication = nn.ReplicationPad3d(padding=pad) + pad_constant = nn.ConstantPad3d(padding=pad, value=value) - output = pad_reflection(data) - np_out = self._get_numpy_out( - input_data, pad, "reflect", data_format="NCL") - self.assertTrue(np.allclose(output.numpy(), np_out)) + data = paddle.to_variable(input_data) - output = pad_replication(data) - np_out = self._get_numpy_out( - input_data, pad, "replicate", data_format="NCL") - self.assertTrue(np.allclose(output.numpy(), np_out)) + output = pad_replication(data) + np_out = self._get_numpy_out( + input_data, pad, "replicate", data_format="NCDHW") + self.assertTrue(np.allclose(output.numpy(), np_out)) - output = pad_constant(data) - np_out = self._get_numpy_out( - input_data, pad, "constant", value=value, data_format="NCL") - self.assertTrue(np.allclose(output.numpy(), np_out)) + output = pad_constant(data) + np_out = self._get_numpy_out( + input_data, pad, "constant", value=value, data_format="NCDHW") + self.assertTrue(np.allclose(output.numpy(), np_out)) class TestPad3dOpError(unittest.TestCase): diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index a0844fcbbfc19..a09035c277b8d 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -47,7 +47,8 @@ 'unfold', # 
'bilinear_tensor_product', 'assign', - 'interpolate' + 'interpolate', + 'cosine_similarity', ] @@ -462,8 +463,8 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): than width-1. The height and depth dimension has the same condition. Parameters: - x (Variable): The input tensor with data type float32/double/int32/int64_t. - pad (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + x (Tensor): The input tensor with data type float32/double/int32/int64_t. + pad (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions of input will be padded. 1. If input dimension is 3, then the pad has the form (pad_left, pad_right). 2. If the input dimension is 4, then the pad has the form (pad_left, pad_right, pad_top, pad_bottom). 3. If the input dimension is 5, then the pad has the form @@ -483,7 +484,7 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: a Tensor padded according to pad and mode and data type is same as input. - Return Type: Variable + Return Type: Tensor Examples: .. code-block:: text @@ -642,13 +643,13 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8): Compute cosine similarity between x1 and x2 along dim. Parameters: - x1 (Variable): First input. float32/double. - x2 (Variable): Second input. float32/double. + x1 (Tensor): First input. float32/double. + x2 (Tensor): Second input. float32/double. dim (int): Dimension of vectors to compute cosine similarity. Default is 1. eps(float): Small value to avoid division by zero. Default is 1e-8. Returns: a Tensor representing cosine similarity between x1 and x2 along dim. - Return Type: Variable + Return Type: Tensor Examples: .. code-block:: text diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 286b85e3f3468..05ce6d40ba13a 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -354,7 +354,7 @@ class ReflectionPad1d(layers.Layer): Uses reflection of the input boundaries to pad the input tensor. Parameters: - padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions of input will be padded. The pad has the form (pad_left, pad_right). data_format (str): An string from: "NCL", "NLC". Specify the data format of the input data. Default is "NCL" @@ -413,7 +413,7 @@ class ReplicationPad1d(layers.Layer): Uses input boundaries to pad the input tensor. Parameters: - padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions of input will be padded. The pad has the form (pad_left, pad_right). data_format (str): An string from: "NCL", "NLC". Specify the data format of the input data. Default is "NCL" @@ -472,7 +472,7 @@ class ConstantPad1d(layers.Layer): Uses a constant value to pad the input tensor. Parameters: - padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions of input will be padded. The pad has the form (pad_left, pad_right). value (float32): The value to fill the padded areas. 
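The flat pad list is always ordered from the innermost axis outward — (left, right), then (top, bottom), then (front, back) — and the tests translate it into per-axis numpy pad_width pairs; a sketch of that translation for the NCDHW case, mirroring the tests' _get_numpy_out (helper name is illustrative):

    def to_np_pad_ncdhw(pad):
        # flat [left, right, top, bottom, front, back] -> (before, after) per axis
        left, right, top, bottom, front, back = pad
        return [(0, 0), (0, 0), (front, back), (top, bottom), (left, right)]
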
Default is 0.0 data_format (str): An string from: "NCL", "NLC". Specify the data format of the input data. @@ -535,7 +535,7 @@ class ConstantPad2d(layers.Layer): Uses a constant value to pad the input tensor. Parameters: - padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom). value (float32): The value to fill the padded areas. Default is 0.0 data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. @@ -578,11 +578,7 @@ class ConstantPad2d(layers.Layer): # [0. 0. 0. 0.]]]] """ - def __init__(self, - padding=[0, 0, 0, 0], - value=0.0, - data_format="NCHW", - name=None): + def __init__(self, padding, value=0.0, data_format="NCHW", name=None): super(ConstantPad2d, self).__init__() self._mode = "constant" self._data_format = data_format @@ -667,7 +663,7 @@ class ReplicationPad2d(layers.Layer): Uses input boundaries to pad the input tensor. Parameters: - padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom). data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. Default is "NCHW" @@ -793,7 +789,7 @@ class ConstantPad3d(layers.Layer): Uses a constant value to pad the input tensor. Parameters: - padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). value (float32): The value to fill the padded areas. Default is 0.0 data_format (str): An string from: "NCDHW", "NDHWC". Specify the data format of the input data. @@ -859,7 +855,7 @@ class ReplicationPad3d(layers.Layer): Uses input boundaries to pad the input tensor. Parameters: - padding (Variable | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions + padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). data_format (str): An string from: "NCDHW", "NDHWC". Specify the data format of the input data. 
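ZeroPad2d is ConstantPad2d pinned to a fill value of 0: it hard-codes mode 'constant' and never passes a value, so F.pad's default of 0 applies. A small consistency check (illustrative, assuming the eager API):

    import numpy as np
    import paddle
    import paddle.nn as nn
    paddle.disable_static()

    x = paddle.to_tensor(np.random.rand(1, 1, 2, 3).astype("float32"))
    pad = [1, 0, 1, 2]
    a = nn.ZeroPad2d(padding=pad)(x)
    b = nn.ConstantPad2d(padding=pad, value=0.0)(x)
    assert np.allclose(a.numpy(), b.numpy())
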
Default is "NCDHW" From 9acc264842d565c6ecb98410f98f78ed717f35a7 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 17 Aug 2020 09:26:35 +0000 Subject: [PATCH 11/15] test=develop, add more tests --- .../fluid/tests/unittests/test_pad3d_op.py | 67 +++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py index fcf0ea25faf17..2af41c90d8b74 100644 --- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py @@ -166,7 +166,7 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(paddle.CUDAPlace(0)) - def check_static_result(self, place): + def check_static_result_1(self, place): paddle.enable_static() with program_guard(Program(), Program()): input_shape = (1, 2, 3, 4, 5) @@ -175,7 +175,7 @@ def check_static_result(self, place): value = 100 input_data = np.random.rand(*input_shape).astype(np.float32) x = paddle.data(name="x", shape=input_shape) - result = F.pad(x=x, pad=pad, value=value, mode='constant') + result = F.pad(x=x, pad=pad, value=value, mode=mode) exe = Executor(place) fetches = exe.run(default_main_program(), feed={"x": input_data}, @@ -184,7 +184,63 @@ def check_static_result(self, place): np_out = self._get_numpy_out(input_data, pad, mode, value) self.assertTrue(np.allclose(fetches[0], np_out)) - def _get_numpy_out(self, input_data, pad, mode, value, data_format="NCDHW"): + def check_static_result_2(self, place): + paddle.enable_static() + with program_guard(Program(), Program()): + input_shape = (2, 3, 4, 5, 6) + pad = [1, 2, 1, 1, 1, 2] + mode = "reflect" + input_data = np.random.rand(*input_shape).astype(np.float32) + x = paddle.data(name="x", shape=input_shape) + result = F.pad(x=x, pad=pad, mode=mode) + exe = Executor(place) + fetches = exe.run(default_main_program(), + feed={"x": input_data}, + fetch_list=[result]) + + np_out = self._get_numpy_out(input_data, pad, mode) + self.assertTrue(np.allclose(fetches[0], np_out)) + + def check_static_result_3(self, place): + paddle.enable_static() + with program_guard(Program(), Program()): + input_shape = (2, 3, 4, 5, 6) + pad = [1, 2, 1, 1, 3, 4] + mode = "replicate" + input_data = np.random.rand(*input_shape).astype(np.float32) + x = paddle.data(name="x", shape=input_shape) + result = F.pad(x=x, pad=pad, mode=mode) + exe = Executor(place) + fetches = exe.run(default_main_program(), + feed={"x": input_data}, + fetch_list=[result]) + + np_out = self._get_numpy_out(input_data, pad, mode) + self.assertTrue(np.allclose(fetches[0], np_out)) + + def check_static_result_4(self, place): + paddle.enable_static() + with program_guard(Program(), Program()): + input_shape = (2, 3, 4, 5, 6) + pad = [1, 2, 1, 1, 3, 4] + mode = "circular" + input_data = np.random.rand(*input_shape).astype(np.float32) + x = paddle.data(name="x", shape=input_shape) + result = F.pad(x=x, pad=pad, mode=mode) + exe = Executor(place) + fetches = exe.run(default_main_program(), + feed={"x": input_data}, + fetch_list=[result]) + + np_out = self._get_numpy_out(input_data, pad, mode) + self.assertTrue(np.allclose(fetches[0], np_out)) + + def _get_numpy_out(self, + input_data, + pad, + mode, + value=0, + data_format="NCDHW"): if data_format == "NCDHW": pad = [ (0, 0), @@ -215,7 +271,10 @@ def _get_numpy_out(self, input_data, pad, mode, value, data_format="NCDHW"): def test_static(self): for place in self.places: - self.check_static_result(place=place) + 
self.check_static_result_1(place=place) + self.check_static_result_2(place=place) + self.check_static_result_3(place=place) + self.check_static_result_4(place=place) def test_dygraph(self): From a4abac943764e86558578f4cecef8d3f48504499 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 17 Aug 2020 11:54:48 +0000 Subject: [PATCH 12/15] test=develop, rename tovar to totensor --- .../fluid/tests/unittests/test_pad3d_op.py | 8 +++---- python/paddle/nn/functional/common.py | 8 +++---- python/paddle/nn/layer/common.py | 22 +++++++++---------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py index 2af41c90d8b74..eaaec9fbf982b 100644 --- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py @@ -286,7 +286,7 @@ def test_dygraph(self): value = 100 input_data = np.random.rand(*input_shape).astype(np.float32) np_out = self._get_numpy_out(input_data, pad, mode, value) - tensor_data = paddle.to_variable(input_data) + tensor_data = paddle.to_tensor(input_data) y = F.pad(tensor_data, pad=pad, mode=mode, value=value) self.assertTrue(np.allclose(y.numpy(), np_out)) @@ -338,7 +338,7 @@ def test_class(self): pad_replication = nn.ReplicationPad1d(padding=pad) pad_constant = nn.ConstantPad1d(padding=pad, value=value) - data = paddle.to_variable(input_data) + data = paddle.to_tensor(input_data) output = pad_reflection(data) np_out = self._get_numpy_out( @@ -405,7 +405,7 @@ def test_class(self): pad_constant = nn.ConstantPad2d(padding=pad, value=value) pad_zero = nn.ZeroPad2d(padding=pad) - data = paddle.to_variable(input_data) + data = paddle.to_tensor(input_data) output = pad_reflection(data) np_out = self._get_numpy_out( @@ -477,7 +477,7 @@ def test_class(self): pad_replication = nn.ReplicationPad3d(padding=pad) pad_constant = nn.ConstantPad3d(padding=pad, value=value) - data = paddle.to_variable(input_data) + data = paddle.to_tensor(input_data) output = pad_replication(data) np_out = self._get_numpy_out( diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index a09035c277b8d..b86731654e3f7 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -536,7 +536,7 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): # example 1 x_shape = (1, 1, 3) x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + 1 - tensor_x = paddle.to_variable(x) + tensor_x = paddle.to_tensor(x) y = F.pad(tensor_x, pad=[2, 3], value=1, mode='constant') print(y.numpy()) # [[[1. 1. 1. 2. 3. 1. 1. 1.]]] @@ -544,7 +544,7 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): # example 2 x_shape = (1, 1, 2, 3) x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + 1 - tensor_x = paddle.to_variable(x) + tensor_x = paddle.to_tensor(x) y = F.pad(tensor_x, pad=[1, 2, 1, 1], value=1, mode='circular') print(y.numpy()) # [[[[6. 4. 5. 6. 4. 5.] 
@@ -676,8 +676,8 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8): np.random.seed(0) x1 = np.random.rand(2,3) x2 = np.random.rand(2,3) - x1 = paddle.to_variable(x1) - x2 = paddle.to_variable(x2) + x1 = paddle.to_tensor(x1) + x2 = paddle.to_tensor(x2) result = paddle.nn.functional.cosine_similarity(x1, x2, dim=0) print(result.numpy()) # [0.99806249 0.9817672 0.94987036] diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 05ce6d40ba13a..c4823298f2035 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -385,7 +385,7 @@ class ReflectionPad1d(layers.Layer): pad = [1, 2] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ReflectionPad1d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[2. 1. 2. 3. 2. 1.] @@ -444,7 +444,7 @@ class ReplicationPad1d(layers.Layer): pad = [1, 2] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ReplicationPad1d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[1. 1. 2. 3. 3. 3.] @@ -505,7 +505,7 @@ class ConstantPad1d(layers.Layer): pad = [1, 2] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ConstantPad1d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[0. 1. 2. 3. 0. 0.] @@ -568,7 +568,7 @@ class ConstantPad2d(layers.Layer): pad = [1, 0, 1, 2] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ConstantPad2d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[[0. 0. 0. 0.] @@ -632,7 +632,7 @@ class ZeroPad2d(layers.Layer): pad = [1, 0, 1, 2] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ZeroPad2d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[[0. 0. 0. 0.] @@ -694,7 +694,7 @@ class ReplicationPad2d(layers.Layer): pad = [1, 0, 1, 2] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ReplicationPad2d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[[1. 1. 2. 3.] @@ -756,7 +756,7 @@ class ReflectionPad2d(layers.Layer): pad = [1, 0, 1, 2] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ReflectionPad2d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[[ 5. 4. 5. 6.] @@ -822,7 +822,7 @@ class ConstantPad3d(layers.Layer): pad = [1, 0, 1, 2, 0, 0] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ConstantPad3d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[[[0. 0. 0. 0.] @@ -886,7 +886,7 @@ class ReplicationPad3d(layers.Layer): pad = [1, 0, 1, 2, 0, 0] data = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) + 1 my_pad = nn.ReplicationPad3d(padding=pad) - data = paddle.to_variable(data) + data = paddle.to_tensor(data) result = my_pad(data) print(result.numpy()) # [[[[[1. 1. 2. 3.] 
@@ -948,8 +948,8 @@ class CosineSimilarity(layers.Layer):
             np.random.seed(0)
             x1 = np.random.rand(2,3)
             x2 = np.random.rand(2,3)
-            x1 = paddle.to_variable(x1)
-            x2 = paddle.to_variable(x2)
+            x1 = paddle.to_tensor(x1)
+            x2 = paddle.to_tensor(x2)

             cos_sim_func = nn.CosineSimilarity(dim=0)
             result = cos_sim_func(x1, x2)
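For the CosineSimilarity docstring touched above, the printed values
follow directly from the definition: with dim=0 the similarity is
computed per column. A NumPy sketch (illustrative only; it assumes the
eps term is negligible for these inputs):

    import numpy as np

    np.random.seed(0)
    x1 = np.random.rand(2, 3)
    x2 = np.random.rand(2, 3)
    # Cosine similarity along dim=0: one score per column.
    cos = (x1 * x2).sum(axis=0) / (
        np.linalg.norm(x1, axis=0) * np.linalg.norm(x2, axis=0))
    print(cos)
    # Matches the docstring output: [0.99806249 0.9817672  0.94987036]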
"NCDHW" - self.variable_paddings = True +# class TestCase1(TestPad3dOp): +# def initTestCase(self): +# self.shape = (2, 3, 4, 5, 6) +# self.paddings = [0, 1, 2, 3, 4, 5] +# self.mode = "constant" +# self.data_format = "NCDHW" +# self.value = 1.0 + +# class TestCase2(TestPad3dOp): +# def initTestCase(self): +# self.shape = (2, 3, 4, 5, 6) +# self.paddings = [1, 1, 1, 1, 1, 1] +# self.mode = "constant" +# self.data_format = "NDHWC" +# self.value = 1.0 + +# class TestCase3(TestPad3dOp): +# def initTestCase(self): +# self.shape = (2, 3, 4, 5, 6) +# self.paddings = [0, 1, 1, 0, 2, 3] +# self.mode = "reflect" +# self.data_format = "NCDHW" + +# class TestCase4(TestPad3dOp): +# def initTestCase(self): +# self.shape = (4, 4, 4, 4, 4) +# self.paddings = [0, 1, 2, 1, 2, 3] +# self.mode = "reflect" +# self.data_format = "NDHWC" + +# class TestCase5(TestPad3dOp): +# def initTestCase(self): +# self.shape = (2, 3, 4, 5, 6) +# self.paddings = [0, 1, 2, 3, 2, 1] +# self.mode = "replicate" +# self.data_format = "NCDHW" + +# class TestCase6(TestPad3dOp): +# def initTestCase(self): +# self.shape = (4, 4, 4, 4, 4) +# self.paddings = [5, 4, 2, 1, 2, 3] +# self.mode = "replicate" +# self.data_format = "NDHWC" + +# class TestCase7(TestPad3dOp): +# def initTestCase(self): +# self.shape = (2, 3, 4, 5, 6) +# self.paddings = [0, 1, 2, 3, 2, 1] +# self.mode = "circular" +# self.data_format = "NCDHW" + +# class TestCase8(TestPad3dOp): +# def initTestCase(self): +# self.shape = (4, 4, 4, 4, 4) +# self.paddings = [0, 1, 2, 1, 2, 3] +# self.mode = "circular" +# self.data_format = "NDHWC" class TestPadAPI(unittest.TestCase): @@ -192,14 +176,19 @@ def check_static_result_2(self, place): mode = "reflect" input_data = np.random.rand(*input_shape).astype(np.float32) x = paddle.data(name="x", shape=input_shape) - result = F.pad(x=x, pad=pad, mode=mode) + result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW") + result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC") exe = Executor(place) fetches = exe.run(default_main_program(), feed={"x": input_data}, - fetch_list=[result]) + fetch_list=[result1, result2]) - np_out = self._get_numpy_out(input_data, pad, mode) - self.assertTrue(np.allclose(fetches[0], np_out)) + np_out1 = self._get_numpy_out( + input_data, pad, mode, data_format="NCDHW") + np_out2 = self._get_numpy_out( + input_data, pad, mode, data_format="NDHWC") + self.assertTrue(np.allclose(fetches[0], np_out1)) + self.assertTrue(np.allclose(fetches[1], np_out2)) def check_static_result_3(self, place): paddle.enable_static() @@ -209,14 +198,19 @@ def check_static_result_3(self, place): mode = "replicate" input_data = np.random.rand(*input_shape).astype(np.float32) x = paddle.data(name="x", shape=input_shape) - result = F.pad(x=x, pad=pad, mode=mode) + result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW") + result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC") exe = Executor(place) fetches = exe.run(default_main_program(), feed={"x": input_data}, - fetch_list=[result]) + fetch_list=[result1, result2]) - np_out = self._get_numpy_out(input_data, pad, mode) - self.assertTrue(np.allclose(fetches[0], np_out)) + np_out1 = self._get_numpy_out( + input_data, pad, mode, data_format="NCDHW") + np_out2 = self._get_numpy_out( + input_data, pad, mode, data_format="NDHWC") + self.assertTrue(np.allclose(fetches[0], np_out1)) + self.assertTrue(np.allclose(fetches[1], np_out2)) def check_static_result_4(self, place): paddle.enable_static() @@ -226,14 +220,19 @@ def check_static_result_4(self, place): mode = 
"circular" input_data = np.random.rand(*input_shape).astype(np.float32) x = paddle.data(name="x", shape=input_shape) - result = F.pad(x=x, pad=pad, mode=mode) + result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW") + result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC") exe = Executor(place) fetches = exe.run(default_main_program(), feed={"x": input_data}, - fetch_list=[result]) + fetch_list=[result1, result2]) - np_out = self._get_numpy_out(input_data, pad, mode) - self.assertTrue(np.allclose(fetches[0], np_out)) + np_out1 = self._get_numpy_out( + input_data, pad, mode, data_format="NCDHW") + np_out2 = self._get_numpy_out( + input_data, pad, mode, data_format="NDHWC") + self.assertTrue(np.allclose(fetches[0], np_out1)) + self.assertTrue(np.allclose(fetches[1], np_out2)) def _get_numpy_out(self, input_data, @@ -249,7 +248,7 @@ def _get_numpy_out(self, (pad[2], pad[3]), (pad[0], pad[1]), ] - else: + elif data_format == "NDHWC": pad = [ (0, 0), (pad[4], pad[5]), @@ -257,6 +256,32 @@ def _get_numpy_out(self, (pad[0], pad[1]), (0, 0), ] + elif data_format == "NCHW": + pad = [ + (0, 0), + (0, 0), + (pad[2], pad[3]), + (pad[0], pad[1]), + ] + elif data_format == "NHWC": + pad = [ + (0, 0), + (pad[2], pad[3]), + (pad[0], pad[1]), + (0, 0), + ] + elif data_format == "NCL": + pad = [ + (0, 0), + (0, 0), + (pad[0], pad[1]), + ] + elif data_format == "NLC": + pad = [ + (0, 0), + (pad[0], pad[1]), + (0, 0), + ] if mode == "constant": out = np.pad(input_data, pad, mode=mode, constant_values=value) From ad220f7587749de3978b560735661e29b6f3165f Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Tue, 18 Aug 2020 03:58:23 +0000 Subject: [PATCH 15/15] test=develop, add more tests --- .../fluid/tests/unittests/test_pad3d_op.py | 232 +++++++++++++----- python/paddle/nn/functional/common.py | 9 +- 2 files changed, 173 insertions(+), 68 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py index 516b568bf1fb0..68589e6d8182f 100644 --- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py @@ -85,63 +85,70 @@ def initTestCase(self): self.pad_value = 0.0 -# class TestCase1(TestPad3dOp): -# def initTestCase(self): -# self.shape = (2, 3, 4, 5, 6) -# self.paddings = [0, 1, 2, 3, 4, 5] -# self.mode = "constant" -# self.data_format = "NCDHW" -# self.value = 1.0 - -# class TestCase2(TestPad3dOp): -# def initTestCase(self): -# self.shape = (2, 3, 4, 5, 6) -# self.paddings = [1, 1, 1, 1, 1, 1] -# self.mode = "constant" -# self.data_format = "NDHWC" -# self.value = 1.0 - -# class TestCase3(TestPad3dOp): -# def initTestCase(self): -# self.shape = (2, 3, 4, 5, 6) -# self.paddings = [0, 1, 1, 0, 2, 3] -# self.mode = "reflect" -# self.data_format = "NCDHW" - -# class TestCase4(TestPad3dOp): -# def initTestCase(self): -# self.shape = (4, 4, 4, 4, 4) -# self.paddings = [0, 1, 2, 1, 2, 3] -# self.mode = "reflect" -# self.data_format = "NDHWC" - -# class TestCase5(TestPad3dOp): -# def initTestCase(self): -# self.shape = (2, 3, 4, 5, 6) -# self.paddings = [0, 1, 2, 3, 2, 1] -# self.mode = "replicate" -# self.data_format = "NCDHW" - -# class TestCase6(TestPad3dOp): -# def initTestCase(self): -# self.shape = (4, 4, 4, 4, 4) -# self.paddings = [5, 4, 2, 1, 2, 3] -# self.mode = "replicate" -# self.data_format = "NDHWC" - -# class TestCase7(TestPad3dOp): -# def initTestCase(self): -# self.shape = (2, 3, 4, 5, 6) -# self.paddings = [0, 1, 2, 3, 2, 1] -# self.mode = "circular" -# 
From ad220f7587749de3978b560735661e29b6f3165f Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Tue, 18 Aug 2020 03:58:23 +0000
Subject: [PATCH 15/15] test=develop, add more tests

---
 .../fluid/tests/unittests/test_pad3d_op.py | 232 +++++++++++++-----
 python/paddle/nn/functional/common.py      |   9 +-
 2 files changed, 173 insertions(+), 68 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
index 516b568bf1fb0..68589e6d8182f 100644
--- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
@@ -85,63 +85,70 @@ def initTestCase(self):
         self.pad_value = 0.0

-# class TestCase1(TestPad3dOp):
-#     def initTestCase(self):
-#         self.shape = (2, 3, 4, 5, 6)
-#         self.paddings = [0, 1, 2, 3, 4, 5]
-#         self.mode = "constant"
-#         self.data_format = "NCDHW"
-#         self.value = 1.0
-
-# class TestCase2(TestPad3dOp):
-#     def initTestCase(self):
-#         self.shape = (2, 3, 4, 5, 6)
-#         self.paddings = [1, 1, 1, 1, 1, 1]
-#         self.mode = "constant"
-#         self.data_format = "NDHWC"
-#         self.value = 1.0
-
-# class TestCase3(TestPad3dOp):
-#     def initTestCase(self):
-#         self.shape = (2, 3, 4, 5, 6)
-#         self.paddings = [0, 1, 1, 0, 2, 3]
-#         self.mode = "reflect"
-#         self.data_format = "NCDHW"
-
-# class TestCase4(TestPad3dOp):
-#     def initTestCase(self):
-#         self.shape = (4, 4, 4, 4, 4)
-#         self.paddings = [0, 1, 2, 1, 2, 3]
-#         self.mode = "reflect"
-#         self.data_format = "NDHWC"
-
-# class TestCase5(TestPad3dOp):
-#     def initTestCase(self):
-#         self.shape = (2, 3, 4, 5, 6)
-#         self.paddings = [0, 1, 2, 3, 2, 1]
-#         self.mode = "replicate"
-#         self.data_format = "NCDHW"
-
-# class TestCase6(TestPad3dOp):
-#     def initTestCase(self):
-#         self.shape = (4, 4, 4, 4, 4)
-#         self.paddings = [5, 4, 2, 1, 2, 3]
-#         self.mode = "replicate"
-#         self.data_format = "NDHWC"
-
-# class TestCase7(TestPad3dOp):
-#     def initTestCase(self):
-#         self.shape = (2, 3, 4, 5, 6)
-#         self.paddings = [0, 1, 2, 3, 2, 1]
-#         self.mode = "circular"
-#         self.data_format = "NCDHW"
-
-# class TestCase8(TestPad3dOp):
-#     def initTestCase(self):
-#         self.shape = (4, 4, 4, 4, 4)
-#         self.paddings = [0, 1, 2, 1, 2, 3]
-#         self.mode = "circular"
-#         self.data_format = "NDHWC"
+class TestCase1(TestPad3dOp):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6)
+        self.paddings = [0, 1, 2, 3, 4, 5]
+        self.mode = "constant"
+        self.data_format = "NCDHW"
+        self.value = 1.0
+
+
+class TestCase2(TestPad3dOp):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6)
+        self.paddings = [1, 1, 1, 1, 1, 1]
+        self.mode = "constant"
+        self.data_format = "NDHWC"
+        self.value = 1.0
+
+
+class TestCase3(TestPad3dOp):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6)
+        self.paddings = [0, 1, 1, 0, 2, 3]
+        self.mode = "reflect"
+        self.data_format = "NCDHW"
+
+
+class TestCase4(TestPad3dOp):
+    def initTestCase(self):
+        self.shape = (4, 4, 4, 4, 4)
+        self.paddings = [0, 1, 2, 1, 2, 3]
+        self.mode = "reflect"
+        self.data_format = "NDHWC"
+
+
+class TestCase5(TestPad3dOp):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6)
+        self.paddings = [0, 1, 2, 3, 2, 1]
+        self.mode = "replicate"
+        self.data_format = "NCDHW"
+
+
+class TestCase6(TestPad3dOp):
+    def initTestCase(self):
+        self.shape = (4, 4, 4, 4, 4)
+        self.paddings = [5, 4, 2, 1, 2, 3]
+        self.mode = "replicate"
+        self.data_format = "NDHWC"
+
+
+class TestCase7(TestPad3dOp):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6)
+        self.paddings = [0, 1, 2, 3, 2, 1]
+        self.mode = "circular"
+        self.data_format = "NCDHW"
+
+
+class TestCase8(TestPad3dOp):
+    def initTestCase(self):
+        self.shape = (4, 4, 4, 4, 4)
+        self.paddings = [0, 1, 2, 1, 2, 3]
+        self.mode = "circular"
+        self.data_format = "NDHWC"


 class TestPadAPI(unittest.TestCase):
@@ -301,8 +308,7 @@ def test_static(self):
             self.check_static_result_3(place=place)
             self.check_static_result_4(place=place)

-    def test_dygraph(self):
-
+    def test_dygraph_1(self):
         paddle.disable_static()

         input_shape = (1, 2, 3, 4, 5)
         pad = [1, 2, 1, 1, 3, 4]
         mode = "constant"
         value = 100
         input_data = np.random.rand(*input_shape).astype(np.float32)
-        np_out = self._get_numpy_out(input_data, pad, mode, value)
+        np_out1 = self._get_numpy_out(
+            input_data, pad, mode, value, data_format="NCDHW")
+        np_out2 = self._get_numpy_out(
+            input_data, pad, mode, value, data_format="NDHWC")
         tensor_data = paddle.to_tensor(input_data)
+
+        y1 = F.pad(tensor_data,
+                   pad=pad,
+                   mode=mode,
+                   value=value,
+                   data_format="NCDHW")
+        y2 = F.pad(tensor_data,
+                   pad=pad,
+                   mode=mode,
+                   value=value,
+                   data_format="NDHWC")
+
+        self.assertTrue(np.allclose(y1.numpy(), np_out1))
+        self.assertTrue(np.allclose(y2.numpy(), np_out2))
+
+    def test_dygraph_2(self):
+        paddle.disable_static()

-        y = F.pad(tensor_data, pad=pad, mode=mode, value=value)
-        self.assertTrue(np.allclose(y.numpy(), np_out))
+        input_shape = (2, 3, 4, 5)
+        pad = [1, 1, 3, 4]
+        mode = "constant"
+        value = 100
+        input_data = np.random.rand(*input_shape).astype(np.float32)
+        np_out1 = self._get_numpy_out(
+            input_data, pad, mode, value, data_format="NCHW")
+        np_out2 = self._get_numpy_out(
+            input_data, pad, mode, value, data_format="NHWC")
+        tensor_data = paddle.to_tensor(input_data)
+        tensor_pad = paddle.to_tensor(pad, dtype="int32")
+
+        y1 = F.pad(tensor_data,
+                   pad=tensor_pad,
+                   mode=mode,
+                   value=value,
+                   data_format="NCHW")
+        y2 = F.pad(tensor_data,
+                   pad=tensor_pad,
+                   mode=mode,
+                   value=value,
+                   data_format="NHWC")
+
+        self.assertTrue(np.allclose(y1.numpy(), np_out1))
+        self.assertTrue(np.allclose(y2.numpy(), np_out2))
+
+    def test_dygraph_3(self):
+        paddle.disable_static()
+
+        input_shape = (3, 4, 5)
+        pad = [3, 4]
+        mode = "constant"
+        value = 100
+        input_data = np.random.rand(*input_shape).astype(np.float32)
+        np_out1 = self._get_numpy_out(
+            input_data, pad, mode, value, data_format="NCL")
+        np_out2 = self._get_numpy_out(
+            input_data, pad, mode, value, data_format="NLC")
+        tensor_data = paddle.to_tensor(input_data)
+        tensor_pad = paddle.to_tensor(pad, dtype="int32")
+
+        y1 = F.pad(tensor_data,
+                   pad=tensor_pad,
+                   mode=mode,
+                   value=value,
+                   data_format="NCL")
+        y2 = F.pad(tensor_data,
+                   pad=tensor_pad,
+                   mode=mode,
+                   value=value,
+                   data_format="NLC")
+
+        self.assertTrue(np.allclose(y1.numpy(), np_out1))
+        self.assertTrue(np.allclose(y2.numpy(), np_out2))


 class TestPad1dAPI(unittest.TestCase):
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index b86731654e3f7..e90db0b67d78f 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -610,12 +610,9 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):

     if in_dygraph_mode():
         if isinstance(pad, Variable):
-            out = core.ops.pad3d(x, pad, "mode", mode, "value", value,
-                                 "data_format", data_format, "name", name)
-        else:
-            out = core.ops.pad3d(x, "paddings", pad, "mode", mode, "value",
-                                 value, "data_format", data_format, "name",
-                                 name)
+            pad = pad.numpy()
+        out = core.ops.pad3d(x, "paddings", pad, "mode", mode, "value", value,
+                             "data_format", data_format, "name", name)
     else:
         attrs = {'mode': mode, 'value': value, 'data_format': data_format}
         inputs = {'X': [x]}
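With this last change, a Tensor `pad` is materialized via `.numpy()` in
dygraph mode, so list and Tensor paddings both take the attribute-based
pad3d path. A usage sketch against the post-patch API, mirroring the new
dygraph tests above (illustrative only; it assumes the dygraph setup
shown in those tests):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.to_tensor(np.arange(1, 7, dtype=np.float32).reshape(1, 1, 2, 3))

    # A Python list and an int32 Tensor now behave identically as `pad`.
    y1 = F.pad(x, pad=[1, 1, 0, 0], mode="constant", value=0, data_format="NCHW")
    y2 = F.pad(x, pad=paddle.to_tensor([1, 1, 0, 0], dtype="int32"),
               mode="constant", value=0, data_format="NCHW")
    assert np.allclose(y1.numpy(), y2.numpy())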