diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py new file mode 100644 index 0000000000000..c6d3c6e7d0492 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py @@ -0,0 +1,129 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from op_test import OpTest, _set_use_system_allocator +from paddle.fluid.framework import grad_var_name +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import paddle + + +class TestBatchNorm(unittest.TestCase): + def test_name(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + with fluid.dygraph.guard(p): + batch_norm1d = paddle.nn.BatchNorm1d(1, name="test") + + def test_error(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + #paddle.disable_static() + x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') + x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32') + + def error1d(): + x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') + batch_norm1d = paddle.nn.BatchNorm1d(1) + batch_norm1d(fluid.dygraph.to_variable(x_data_4)) + + def error2d(): + x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32') + batch_norm2d = paddle.nn.BatchNorm2d(1) + batch_norm2d(fluid.dygraph.to_variable(x_data_3)) + + def error3d(): + x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') + batch_norm3d = paddle.nn.BatchNorm3d(1) + batch_norm3d(fluid.dygraph.to_variable(x_data_4)) + + with fluid.dygraph.guard(p): + self.assertRaises(ValueError, error1d) + self.assertRaises(ValueError, error2d) + self.assertRaises(ValueError, error3d) + + def test_dygraph(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + shape = [4, 10, 4, 4] + + def compute_v1(x, is_test, trainable_statistics): + with fluid.dygraph.guard(p): + bn = fluid.dygraph.BatchNorm( + shape[1], + is_test=is_test, + trainable_statistics=trainable_statistics) + y = bn(fluid.dygraph.to_variable(x)) + return y.numpy() + + def compute_v2(x): + with fluid.dygraph.guard(p): + bn = paddle.nn.BatchNorm2d(shape[1]) + y = bn(fluid.dygraph.to_variable(x)) + return y.numpy() + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x, False, False) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2)) + + def test_static(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + exe = fluid.Executor(p) + shape = [4, 10, 16, 16] + + def compute_v1(x_np, is_test, trainable_statistics): + with program_guard(Program(), Program()): + bn = fluid.dygraph.BatchNorm( + shape[1], + is_test=is_test, + trainable_statistics=trainable_statistics) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = bn(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + def compute_v2(x_np): + with program_guard(Program(), Program()): + bn = paddle.nn.BatchNorm2d(shape[1]) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = bn(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x, False, False) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py new file mode 100644 index 0000000000000..654e8d6f129e1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py @@ -0,0 +1,86 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from op_test import OpTest, _set_use_system_allocator +from paddle.fluid.framework import grad_var_name +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import paddle + + +class TestDygraphGroupNormv2(unittest.TestCase): + def test_dygraph(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + shape = [2, 6, 2, 2] + + def compute_v1(x): + with fluid.dygraph.guard(p): + gn = fluid.dygraph.GroupNorm(channels=2, groups=2) + y = gn(fluid.dygraph.to_variable(x)) + return y.numpy() + + def compute_v2(x): + with fluid.dygraph.guard(p): + gn = paddle.nn.GroupNorm(num_channels=2, num_groups=2) + y = gn(fluid.dygraph.to_variable(x)) + return y.numpy() + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2)) + + def test_static(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("layer_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + exe = fluid.Executor(p) + shape = [2, 6, 2, 2] + + def compute_v1(x_np): + with program_guard(Program(), Program()): + gn = fluid.dygraph.GroupNorm(channels=2, groups=2) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = gn(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + def compute_v2(x_np): + with program_guard(Program(), Program()): + gn = paddle.nn.GroupNorm(num_channels=2, num_groups=2) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = gn(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py new file mode 100644 index 0000000000000..b02ba1a584b52 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py @@ -0,0 +1,115 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from op_test import OpTest, _set_use_system_allocator +from paddle.fluid.framework import grad_var_name +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import paddle + + +class TestInstanceNorm(unittest.TestCase): + def test_error(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu( + "instance_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + + def error1d(): + x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') + instance_norm1d = paddle.nn.InstanceNorm1d(1) + instance_norm1d(fluid.dygraph.to_variable(x_data_4)) + + def error2d(): + x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32') + instance_norm2d = paddle.nn.InstanceNorm2d(1) + instance_norm2d(fluid.dygraph.to_variable(x_data_3)) + + def error3d(): + x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') + instance_norm3d = paddle.nn.BatchNorm3d(1) + instance_norm3d(fluid.dygraph.to_variable(x_data_4)) + + with fluid.dygraph.guard(p): + self.assertRaises(ValueError, error1d) + self.assertRaises(ValueError, error2d) + self.assertRaises(ValueError, error3d) + + def test_dygraph(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu( + "instance_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + shape = [4, 10, 4, 4] + + def compute_v1(x): + with fluid.dygraph.guard(p): + bn = fluid.dygraph.InstanceNorm(shape[1]) + y = bn(fluid.dygraph.to_variable(x)) + return y.numpy() + + def compute_v2(x): + with fluid.dygraph.guard(p): + bn = paddle.nn.InstanceNorm2d(shape[1]) + y = bn(fluid.dygraph.to_variable(x)) + return y.numpy() + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2)) + + def test_static(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu( + "instance_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + exe = fluid.Executor(p) + shape = [4, 10, 16, 16] + + def compute_v1(x_np): + with program_guard(Program(), Program()): + ins = fluid.dygraph.InstanceNorm(shape[1]) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = ins(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + def compute_v2(x_np): + with program_guard(Program(), Program()): + ins = paddle.nn.InstanceNorm2d(shape[1]) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = ins(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py new file mode 100644 index 0000000000000..f324e4bd377c6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py @@ -0,0 +1,86 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from op_test import OpTest, _set_use_system_allocator +from paddle.fluid.framework import grad_var_name +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import paddle + + +class TestDygraphLayerNormv2(unittest.TestCase): + def test_dygraph(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("layer_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + shape = [4, 10, 4, 4] + + def compute_v1(x): + with fluid.dygraph.guard(p): + ln = fluid.dygraph.LayerNorm(shape[1:]) + y = ln(fluid.dygraph.to_variable(x)) + return y.numpy() + + def compute_v2(x): + with fluid.dygraph.guard(p): + ln = paddle.nn.LayerNorm(shape[1:]) + y = ln(fluid.dygraph.to_variable(x)) + return y.numpy() + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2)) + + def test_static(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("layer_norm"): + places.append(fluid.CUDAPlace(0)) + for p in places: + exe = fluid.Executor(p) + shape = [4, 10, 16, 16] + + def compute_v1(x_np): + with program_guard(Program(), Program()): + ln = fluid.dygraph.LayerNorm(shape[1:]) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = ln(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + def compute_v2(x_np): + with program_guard(Program(), Program()): + ln = paddle.nn.LayerNorm(shape[1:]) + x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = ln(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + self.assertTrue(np.allclose(y1, y2)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 290622450a958..f076fb086ca27 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -127,6 +127,12 @@ from .layer.norm import LayerNorm #DEFINE_ALIAS from .layer.norm import SpectralNorm #DEFINE_ALIAS from .layer.norm import InstanceNorm #DEFINE_ALIAS +from .layer.norm import InstanceNorm1d #DEFINE_ALIAS +from .layer.norm import InstanceNorm2d #DEFINE_ALIAS +from .layer.norm import InstanceNorm3d #DEFINE_ALIAS +from .layer.norm import BatchNorm1d #DEFINE_ALIAS +from .layer.norm import BatchNorm2d #DEFINE_ALIAS +from .layer.norm import BatchNorm3d #DEFINE_ALIAS # from .layer.rnn import RNNCell #DEFINE_ALIAS # from .layer.rnn import GRUCell #DEFINE_ALIAS # from .layer.rnn import LSTMCell #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a952cd587be83..1e14b1bc34fcf 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -157,12 +157,12 @@ from .loss import ssd_loss #DEFINE_ALIAS from .loss import teacher_student_sigmoid_loss #DEFINE_ALIAS from .loss import ctc_loss #DEFINE_ALIAS -# from .norm import batch_norm #DEFINE_ALIAS # from .norm import data_norm #DEFINE_ALIAS # from .norm import group_norm #DEFINE_ALIAS -# from .norm import instance_norm #DEFINE_ALIAS from .norm import l2_normalize #DEFINE_ALIAS -# from .norm import layer_norm #DEFINE_ALIAS +from .norm import batch_norm #DEFINE_ALIAS +from .norm import instance_norm #DEFINE_ALIAS +from .norm import layer_norm #DEFINE_ALIAS from .norm import lrn #DEFINE_ALIAS from .norm import normalize #DEFINE_ALIAS # from .norm import spectral_norm #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 0b007041b4ab3..13e86e5712a1c 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -18,16 +18,19 @@ from ...fluid.data_feeder import check_variable_and_dtype, check_type from ...fluid.layer_helper import LayerHelper from ...fluid.framework import in_dygraph_mode, core +from ...framework import create_parameter from ...fluid.layers import l2_normalize #DEFINE_ALIAS from ...fluid.layers import lrn #DEFINE_ALIAS +from ...fluid.initializer import Constant +from ...fluid.param_attr import ParamAttr +from ...fluid import core, dygraph_utils __all__ = [ - # 'batch_norm', + 'batch_norm', # 'data_norm', - # 'group_norm', - # 'instance_norm', + 'instance_norm', 'l2_normalize', - # 'layer_norm', + 'layer_norm', 'lrn', 'normalize', # 'spectral_norm' @@ -110,3 +113,286 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None): eps = out.block.create_var(dtype=out.dtype) paddle.fill_constant([1], out.dtype, epsilon, out=eps) return paddle.elementwise_div(x, paddle.maximum(out, eps), name=name) + + +def batch_norm(x, + running_mean, + running_var, + weight, + bias, + training=False, + momentum=0.9, + epsilon=1e-05, + data_format="NCHW", + name=None): + """ + Applies Batch Normalization as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . + + nn.functional.batch_norm is uesd for nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d. Please use above API for BatchNorm. + + Parameters: + x(Tesnor): input value. It's data type should be float32, float64. + running_mean(Tensor): running mean. + running_var(Tensor): running variance. + weight(Tensor): The weight tensor of batch_norm, can not be None. + bias(Tensor): The bias tensor of batch_norm can not be None. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Defalut False. + data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW" or "NCDHW". Defalut "NCHW". + name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + x = np.random.seed(123) + x = np.random.random(size=(2, 1, 2, 3)).astype('float32') + running_mean = np.random.random(size=1).astype('float32') + running_variance = np.random.random(size=1).astype('float32') + weight_data = np.random.random(size=1).astype('float32') + bias_data = np.random.random(size=1).astype('float32') + x = paddle.to_tensor(x) + rm = paddle.to_tensor(running_mean) + rv = paddle.to_tensor(running_variance) + w = paddle.to_tensor(weight_data) + b = paddle.to_tensor(bias_data) + batch_norm_out = paddle.nn.functional.batch_norm(x, rm, rv, w, b) + print batch_norm_out + """ + + assert len(x.shape) >= 2, "input dim must be larger than 1" + + # we use not training means use_global_status, more details see nn._BatchNormBase + use_global_stats = not training + # input ad out must share the memory + mean_out = running_mean + variance_out = running_var + + if in_dygraph_mode(): + # for dygraph need tuple + attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout", + data_format, "use_mkldnn", False, "fuse_with_relu", False, + "use_global_stats", use_global_stats) + batch_norm_out, _, _, _, _, _ = core.ops.batch_norm( + x, weight, bias, running_mean, running_var, mean_out, variance_out, + *attrs) + + return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=None) + + check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'], + 'BatchNorm') + + # for static need dict + attrs = { + "momentum": momentum, + "epsilon": epsilon, + "data_layout": data_format, + "use_mkldnn": False, + "fuse_with_relu": False, + "use_global_stats": use_global_stats, + } + + inputs = { + "X": [x], + "Scale": [weight], + "Bias": [bias], + "Mean": [running_mean], + "Variance": [running_var] + } + + helper = LayerHelper('batch_norm', **locals()) + + dtype = x.dtype if x.dtype is not 'float16' else 'float32' + saved_mean = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + saved_variance = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + batch_norm_out = helper.create_variable_for_type_inference(dtype) + + outputs = { + "Y": [batch_norm_out], + "MeanOut": [running_mean], + "VarianceOut": [running_var], + "SavedMean": [saved_mean], + "SavedVariance": [saved_variance] + } + + helper.append_op( + type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs) + + return helper.append_activation(batch_norm_out) + + +def layer_norm(x, + normalized_shape, + weight=None, + bias=None, + epsilon=1e-05, + name=None): + """ + see more detail in paddle.nn.LayerNorm + + Parameters: + x(Tensor): Input Tensor. It's data type should be float32, float64. + normalized_shape(int|list|tuple): Input shape from an expected input of + size :math:`[*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[-1]]`. + If it is a single integer, this module will normalize over the last dimension + which is expected to be of that specific size. + epsilon(float, optional): The small value added to the variance to prevent + division by zero. Default: 1e-05. + weight(Tensor, optional): The weight tensor of batch_norm. Default: None. + bias(Tensor, optional): The bias tensor of batch_norm. Default: None. + name(str, optional): Name for the LayerNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 2, 2, 3)).astype('float32') + x = paddle.to_tensor(x_data) + layer_norm = paddle.nn.functional.layer_norm(x, x.shape[1:]) + layer_norm_out = layer_norm(x) + + print(layer_norm_out.numpy) + """ + input_shape = list(x.shape) + input_ndim = len(input_shape) + normalized_ndim = len(normalized_shape) + begin_norm_axis = input_ndim - normalized_ndim + if input_ndim < normalized_ndim or input_shape[ + begin_norm_axis:] != normalized_shape: + str_normalized_shape = str(normalized_shape) + raise ValueError('Given normalized_shape is ' + str_normalized_shape + + ', expected input with shape [*, ' + + str_normalized_shape[ + 1:] + ', but got input shape ' + str(input_shape)) + + if in_dygraph_mode(): + pre_act, _, _ = core.ops.layer_norm(x, weight, bias, 'epsilon', epsilon, + 'begin_norm_axis', begin_norm_axis) + return dygraph_utils._append_activation_in_dygraph(pre_act, act=None) + + check_variable_and_dtype(x, 'input', ['float32', 'float64'], 'LayerNorm') + + inputs = dict() + inputs['X'] = [x] + if weight: + inputs['Scale'] = [weight] + if bias: + inputs['Bias'] = [bias] + attrs = {"epsilon": epsilon, "begin_norm_axis": begin_norm_axis} + + # create output + helper = LayerHelper('layer_norm', **locals()) + mean_out = helper.create_variable_for_type_inference( + dtype=x.type, stop_gradient=True) + variance_out = helper.create_variable_for_type_inference( + dtype=x.type, stop_gradient=True) + layer_norm_out = helper.create_variable_for_type_inference(x.type) + + helper.append_op( + type="layer_norm", + inputs=inputs, + outputs={ + "Y": layer_norm_out, + "Mean": mean_out, + "Variance": variance_out, + }, + attrs={"epsilon": epsilon, + "begin_norm_axis": begin_norm_axis}) + + return helper.append_activation(layer_norm_out) + + +def instance_norm(x, + running_mean=None, + running_var=None, + weight=None, + bias=None, + use_input_stats=True, + momentum=0.9, + eps=1e-05, + data_format="NCHW", + name=None): + """ + See more detail in nn.layer.InstanceNorm2d. + + Parameters: + x(Tensor): Input Tensor. It's data type should be float32, float64. + running_mean(Tensor): running mean. Default None. + running_var(Tensor): running variance. Default None. + weight(Tensor, optional): The weight tensor of instance_norm. Default: None. + bias(Tensor, optional): The bias tensor of instance_norm. Default: None. + eps(float, optional): A value added to the denominator for numerical stability. Default is 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + use_input_stats(bool): Default True. + data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW" or "NCDHW". Defalut "NCHW". + name(str, optional): Name for the InstanceNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Returns: + None. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 2, 2, 3)).astype('float32') + x = paddle.to_tensor(x_data) + instance_norm_out = paddle.nn.functional.instancenorm(x) + + print(instance_norm_out.numpy) + + """ + + if in_dygraph_mode(): + out, _, _ = core.ops.instance_norm(x, weight, bias, "epsilon", eps, + "momentum", momentum, "data_format", + data_format) + return out + + check_variable_and_dtype(x, 'input', ['float32', 'float64'], "InstanceNorm") + + attrs = {"epsilon": eps, "momentum": momentum, "data_format": data_format} + + if weight and bias: + inputs = {"X": [x], "Scale": [weight], "Bias": [bias]} + else: + inputs = {"X": [x]} + + helper = LayerHelper('instance_norm', **locals()) + saved_mean = helper.create_variable_for_type_inference( + dtype=x.dtype, stop_gradient=True) + saved_variance = helper.create_variable_for_type_inference( + dtype=x.dtype, stop_gradient=True) + instance_norm_out = helper.create_variable_for_type_inference(x.dtype) + + outputs = { + "Y": [instance_norm_out], + "SavedMean": [saved_mean], + "SavedVariance": [saved_variance] + } + + helper.append_op( + type="instance_norm", inputs=inputs, outputs=outputs, attrs=attrs) + return instance_norm_out diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 369d462a8089a..c7855b23bf6e6 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -1,4 +1,17 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,28 +27,877 @@ # TODO: define normalization api -import warnings from ...fluid.dygraph.nn import InstanceNorm from ...fluid.dygraph import BatchNorm #DEFINE_ALIAS -from ...fluid.dygraph import GroupNorm #DEFINE_ALIAS -from ...fluid.dygraph import LayerNorm #DEFINE_ALIAS +#from ...fluid.dygraph import GroupNorm #DEFINE_ALIAS + +#from ...fluid.dygraph import LayerNorm #DEFINE_ALIAS from ...fluid.dygraph import SpectralNorm #DEFINE_ALIAS from ...fluid.dygraph import layers + +from ...framework import get_default_dtype, set_default_dtype from ...fluid.framework import in_dygraph_mode from ...fluid.initializer import Constant from ...fluid.param_attr import ParamAttr from ...fluid.data_feeder import check_variable_and_dtype, check_type -from ...fluid import core +from ...fluid import core, dygraph_utils + +from ..functional import batch_norm, layer_norm, instance_norm + +import numpy as np +import numbers +import warnings __all__ = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm', - 'SyncBatchNorm' + 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d', + 'InstanceNorm2d', 'InstanceNorm3d', 'SyncBatchNorm' ] +class _InstanceNormBase(layers.Layer): + """ + This class is based class for InstanceNorm1d, 2d, 3d. + + See InstaceNorm1d, InstanceNorm2d or InstanceNorm3d for more details. + """ + + def __init__(self, + num_features, + epsilon=1e-5, + momentum=0.9, + weight_attr=None, + bias_attr=None, + track_running_stats=False, + data_format="NCHW", + name=None): + super(_InstanceNormBase, self).__init__() + + if weight_attr == False or bias_attr == False: + assert weight_attr == param_attr, "weight_attr and bias_attr must be set to Fasle at the same time in InstanceNorm" + self._epsilon = epsilon + self._weight_attr = weight_attr + self._bias_attr = bias_attr + + if weight_attr != False and bias_attr != False: + self.scale = self.create_parameter( + attr=self._weight_attr, + shape=[num_features], + default_initializer=Constant(1.0), + is_bias=False) + self.bias = self.create_parameter( + attr=self._bias_attr, + shape=[num_features], + default_initializer=Constant(0.0), + is_bias=True) + else: + self.scale = None + self.bias = None + + def _check_input_dim(self, input): + raise NotImplementedError("InstanceNorm Base error") + + def forward(self, input): + self._check_input_dim(input) + + return instance_norm( + input, weight=self.scale, bias=self.bias, eps=self._epsilon) + + +class InstanceNorm1d(_InstanceNormBase): + """ + Applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization . + + DataLayout: NCL `[batch, in_channels, length]` + + :math:`input` is the input features over a mini-batch. + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW} x_i \\qquad &//\\ + \\ mean\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + Note: + `H` means height of feature map, `W` means width of feature map. + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): A value added to the denominator for + numerical stability. Default is 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + track_running_stats(bool, optional): Whether to use global mean and + variance. In train mode, when setting track_running_stats True, the global mean + and variance are also used during train period. Default: False. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr. + If the Initializer of the weight_attr is not set, the parameter is initialized + one. If it is set to False, will not create weight_attr. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of instance_norm. + If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. + If the Initializer of the bias_attr is not set, the bias is initialized zero. + If it is set to False, will not create bias_attr. Default: None. + data_format(str, optional): Specify the input data format, may be "NC", "NCL". Defalut "NCL". + name(str, optional): Name for the InstanceNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + + Shape: + - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length). + - output: 3-D tensor with same shape as input x. + + Returns: + None. + + **Note**: + Momentum and track_running_stats is not effective. The next version will fix the problem . + + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 2, 3)).astype('float32') + x = paddle.to_tensor(x_data) + instance_norm = paddle.nn.InstanceNorm1d(2) + instance_norm_out = instance_norm(x) + + print(instance_norm_out.numpy) + + """ + + def _check_input_dim(self, input): + if len(input.shape) != 2 and len(input.shape) != 3: + raise ValueError('expected 2D or 3D input (got {}D input)'.format( + len(input.shape))) + + +class InstanceNorm2d(_InstanceNormBase): + """ + Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization . + + DataLayout: NCHW `[batch, in_channels, in_height, in_width]` + + + :math:`input` is the input features over a mini-batch. + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW} x_i \\qquad &//\\ + \\ mean\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + Note: + `H` means height of feature map, `W` means width of feature map. + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): A value added to the denominator for + numerical stability. Default is 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + track_running_stats(bool, optional): Whether to use global mean and + variance. In train mode, when setting track_running_stats True, the global mean + and variance are also used during train period. Default: False. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr. + If the Initializer of the weight_attr is not set, the parameter is initialized + one. If it is set to False, will not create weight_attr. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of instance_norm. + If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. + If the Initializer of the bias_attr is not set, the bias is initialized zero. + If it is set to False, will not create bias_attr. Default: None. + data_format(str, optional): Specify the input data format, could be "NCHW". Default: NCHW. + name(str, optional): Name for the InstanceNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Shape: + - x: 4-D tensor with shape: (batch, num_features, height, weight). + - output: 4-D tensor with same shape as input x. + + Returns: + None. + + **Note**: + Momentum and track_running_stats is not effective. The next version will fix the problem . + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 2, 2, 3)).astype('float32') + x = paddle.to_tensor(x_data) + instance_norm = paddle.nn.InstanceNorm2d(2) + instance_norm_out = instance_norm(x) + + print(instance_norm_out.numpy) + """ + + def _check_input_dim(self, input): + if len(input.shape) != 4: + raise ValueError('expected 4D input (got {}D input)'.format( + len(input.shape))) + + +class InstanceNorm3d(_InstanceNormBase): + """ + Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization . + + DataLayout: NCHW `[batch, in_channels, D, in_height, in_width]` + + + :math:`input` is the input features over a mini-batch. + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW} x_i \\qquad &//\\ + \\ mean\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + Note: + `H` means height of feature map, `W` means width of feature map. + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): A value added to the denominator for + numerical stability. Default is 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + track_running_stats(bool, optional): Whether to use global mean and + variance. In train mode, when setting track_running_stats True, the global mean + and variance are also used during train period. Default: False. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr. + If the Initializer of the weight_attr is not set, the parameter is initialized + one. If it is set to False, will not create weight_attr. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of instance_norm. + If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. + If the Initializer of the bias_attr is not set, the bias is initialized zero. + If it is set to False, will not create bias_attr. Default: None. + data_format(str, optional): Specify the input data format, could be "NCDHW". Default: NCDHW. + name(str, optional): Name for the InstanceNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Shape: + - x: 5-D tensor with shape: (batch, num_features, dims, height, weight). + - output: 5-D tensor with same shape as input x. + + Returns: + None. + + **Note**: + Momentum and track_running_stats is not effective. The next version will fix the problem . + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 2, 2, 2, 3)).astype('float32') + x = paddle.to_tensor(x_data) + instance_norm = paddle.nn.InstanceNorm3d(2) + instance_norm_out = instance_norm(x) + + print(instance_norm_out.numpy) + """ + + def _check_input_dim(self, input): + if len(input.shape) != 5: + raise ValueError('expected 5D input (got {}D input)'.format( + len(input.shape))) + + +class GroupNorm(layers.Layer): + """ + This interface is used to construct a callable object of the ``GroupNorm`` class. + For more details, refer to code examples. + It implements the function of the Group Normalization Layer. + Refer to `Group Normalization `_ . + + Parameters: + num_channels(int): The number of channels of input. + num_groups(int): The number of groups that divided from channels. + epsilon(float, optional): The small value added to the variance to prevent + division by zero. Default: 1e-05. + weight_attr(ParamAttr|bool, optional): The parameter attribute for the learnable + scale :math:`g`. If it is set to False, no scale will be added to the output units. + If it is set to None, the bias is initialized one. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the learnable + bias :math:`b`. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. Default: None. + data_format(str, optional): Specify the input data format. Only NCHW is supported. Default: NCHW. + name(str, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Shape: + - x: 4-D tensor with shape: (batch, num_features, height, weight). + - output: 4-D tensor with same shape as input x. + + Returns: + None + + Examples: + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 6, 2, 2)).astype('float32') + x = paddle.to_tensor(x_data) + group_norm = paddle.nn.GroupNorm(num_channels=3, num_groups=6) + group_norm_out = group_norm(x) + + print(group_norm_out.numpy) + """ + + def __init__(self, + num_channels, + num_groups, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_layout='NCHW', + name=None): + super(GroupNorm, self).__init__() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self._epsilon = epsilon + self._num_channels = num_channels + self._num_groups = num_groups + if data_layout != 'NCHW': + raise ValueError("unsupported data layout:" + data_layout) + + param_shape = [self._num_channels] + + self.weight = self.create_parameter( + attr=self._weight_attr or False, + shape=param_shape, + default_initializer=Constant(1.0)) + + self.bias = self.create_parameter( + attr=self._weight_attr or False, shape=param_shape, is_bias=True) + + def forward(self, input): + inputs = {'X': input} + if self.bias is not None: + inputs['Bias'] = self.bias + if self.weight is not None: + inputs['Scale'] = self.weight + + # create output + mean_out = self._helper.create_variable_for_type_inference( + dtype=input.dtype, stop_gradient=True) + variance_out = self._helper.create_variable_for_type_inference( + dtype=input.dtype, stop_gradient=True) + group_norm_out = self._helper.create_variable_for_type_inference( + dtype=input.dtype) + + self._helper.append_op( + type="group_norm", + inputs=inputs, + outputs={ + "Y": group_norm_out, + "Mean": mean_out, + "Variance": variance_out, + }, + attrs={"epsilon": self._epsilon, + "groups": self._num_groups}) + + return self._helper.append_activation(group_norm_out, None) + + +class LayerNorm(layers.Layer): + """ + :alias_main: paddle.nn.LayerNorm + :alias: paddle.nn.LayerNorm,paddle.nn.layer.LayerNorm,paddle.nn.layer.norm.LayerNorm + :old_api: paddle.fluid.dygraph.LayerNorm + + This interface is used to construct a callable object of the ``LayerNorm`` class. + For more details, refer to code examples. + It implements the function of the Layer Normalization Layer and can be applied to mini-batch input data. + Refer to `Layer Normalization `_ + + The formula is as follows: + + .. math:: + + \\mu & = \\frac{1}{H}\\sum_{i=1}^{H} x_i + + \\sigma & = \\sqrt{\\frac{1}{H}\sum_{i=1}^{H}{(x_i - \\mu)^2} + \\epsilon} + + y & = f(\\frac{g}{\\sigma}(x - \\mu) + b) + + - :math:`x`: the vector representation of the summed inputs to the neurons in that layer. + - :math:`H`: the number of hidden units in a layers + - :math:`\\epsilon`: the small value added to the variance to prevent division by zero. + - :math:`g`: the trainable scale parameter. + - :math:`b`: the trainable bias parameter. + + Parameters: + normalized_shape(int|list|tuple): Input shape from an expected input of + size :math:`[*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[-1]]`. + If it is a single integer, this module will normalize over the last dimension + which is expected to be of that specific size. + epsilon(float, optional): The small value added to the variance to prevent + division by zero. Default: 1e-05. + weight_attr(ParamAttr|bool, optional): The parameter attribute for the learnable + gain :math:`g`. If False, weight is None. If is None, a default :code:`ParamAttr` would be added as scale. The + :attr:`param_attr` is initialized as 1 if it is added. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the learnable + bias :math:`b`. If is False, bias is None. If is None, a default :code:`ParamAttr` would be added as bias. The + :attr:`bias_attr` is initialized as 0 if it is added. Default: None. + name(str, optional): Name for the LayerNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Shape: + - x: 2-D, 3-D, 4-D or 5-D tensor. + - output: same shape as input x. + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 2, 2, 3)).astype('float32') + x = paddle.to_tensor(x_data) + layer_norm = paddle.nn.LayerNorm(x_data.shape[1:]) + layer_norm_out = layer_norm(x) + + print(layer_norm_out.numpy) + """ + + def __init__(self, + normalized_shape, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + name=None): + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = [normalized_shape] + + self._normalized_shape = list(normalized_shape) + self._epsilon = epsilon + self._weight_attr = weight_attr + self._bias_attr = bias_attr + param_shape = [np.prod(self._normalized_shape)] + + if weight_attr is False: + self.weight = None + else: + self.weight = self.create_parameter( + attr=self._weight_attr, + shape=param_shape, + default_initializer=Constant(1.0)) + + if bias_attr is False: + self.bias = None + else: + self.bias = self.create_parameter( + attr=self._bias_attr, shape=param_shape, is_bias=True) + + def forward(self, input): + return layer_norm( + input, + normalized_shape=self._normalized_shape, + weight=self.weight, + bias=self.bias, + epsilon=self._epsilon) + + +class _BatchNormBase(layers.Layer): + """ + BatchNorm base . + """ + + def __init__(self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NCHW', + track_running_stats=True, + name=None): + super(_BatchNormBase, self).__init__() + self._num_features = num_features + self._weight_attr = weight_attr + self._bias_attr = bias_attr + + if get_default_dtype() == 'float16': + set_default_dtype('float32') + + param_shape = [num_features] + + # create parameter + self.weight = self.create_parameter( + attr=self._weight_attr, + shape=param_shape, + default_initializer=Constant(1.0)) + self.weight.stop_gradient = (self._weight_attr is False) or ( + self._weight_attr and self._weight_attr.learning_rate == 0.) + + self.bias = self.create_parameter( + attr=self._bias_attr, shape=param_shape, is_bias=True) + self.bias.stop_gradient = (self._bias_attr is False) or ( + self._bias_attr and self._bias_attr.learning_rate == 0.) + + moving_mean_name = None + moving_variance_name = None + + if name is not None: + moving_mean_name = name + "_mean" + moving_variance_name = name + "_variance" + + self._mean = self.create_parameter( + attr=ParamAttr( + name=moving_mean_name, + initializer=Constant(0.0), + trainable=False, + do_model_average=True), + shape=param_shape, + dtype=self._dtype) + self._mean.stop_gradient = True + + self._variance = self.create_parameter( + attr=ParamAttr( + name=moving_variance_name, + initializer=Constant(1.0), + trainable=False, + do_model_average=True), + shape=param_shape, + dtype=self._dtype) + self._variance.stop_gradient = True + + self._data_format = data_format + self._in_place = False + self._momentum = momentum + self._epsilon = epsilon + self._fuse_with_relu = False + self._track_running_stats = track_running_stats + + def _check_input_dim(self, input): + raise NotImplementedError("BatchNorm Base error") + + def forward(self, input): + + self._check_input_dim(input) + + if not self.training and not self._track_running_stats: + raise ValueError( + 'When inference, expected track_running_stats is True.') + + if self.training and not self._track_running_stats: + warnings.warn( + "When training, we now always track global mean and variance.") + + return batch_norm( + input, + self._mean, + self._variance, + weight=self.weight, + bias=self.bias, + training=self.training, + momentum=self._momentum, + epsilon=self._epsilon, + data_format=self._data_format) + + +class BatchNorm1d(_BatchNormBase): + """ + Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputswith additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . + + When track_running_stats = False, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are the statistics of one mini-batch. + Calculated as follows: + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ + \ mini-batch\ mean \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ + + When track_running_stats = True, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch. + They are global or running statistics (moving_mean and moving_variance). It usually got from the + pre-trained model. Calculated as follows: + + .. math:: + moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\ + moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\ + + The normalization function formula is as follows: + + .. math:: + + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + - :math:`\\epsilon` : add a smaller value to the variance to prevent division by zero + - :math:`\\gamma` : trainable proportional parameter + - :math:`\\beta` : trainable deviation parameter + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as weight_attr. If it is set to Fasle, the weight is not learnable. + If the Initializer of the weight_attr is not set, the parameter is initialized with Xavier. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of batch_norm. + If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. + If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. + data_format(str, optional): Specify the input data format, may be "NC", "NCL". Defalut "NCL". + track_running_stats(bool, optional): Whether to use global mean and variance. In train period, + True will track global mean and variance used for inference. When inference, track_running_stats must be + True. Default: True. + name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Shape: + - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length). + - output: 3-D tensor with same shape as input x. + + Returns: + None. + + **Note**: + Now track_running_stats is actucal always true. The next version will fix the problem . + + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 1, 3)).astype('float32') + x = paddle.to_tensor(x_data) + batch_norm = paddle.nn.BatchNorm1d(1) + batch_norm_out = batch_norm(x) + + print(batch_norm_out.numpy) + """ + + def _check_input_dim(self, input): + if len(input.shape) != 2 and len(input.shape) != 3: + raise ValueError('expected 2D or 3D input (got {}D input)'.format( + len(input.shape))) + + +class BatchNorm2d(_BatchNormBase): + """ + Applies Batch Normalization over a 4D input (a mini-batch of 2D inputswith additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . + + When track_running_stats = False, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are the statistics of one mini-batch. + Calculated as follows: + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ + \ mini-batch\ mean \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ + + When track_running_stats = True, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch. + They are global or running statistics (moving_mean and moving_variance). It usually got from the + pre-trained model. Calculated as follows: + + .. math:: + moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\ + moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\ + + The normalization function formula is as follows: + + .. math:: + + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + - :math:`\\epsilon` : add a smaller value to the variance to prevent division by zero + - :math:`\\gamma` : trainable proportional parameter + - :math:`\\beta` : trainable deviation parameter + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as weight_attr. If it is set to Fasle, the weight is not learnable. + If the Initializer of the weight_attr is not set, the parameter is initialized with Xavier. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of batch_norm. + If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. + If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. + data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW. + track_running_stats(bool, optional): Whether to use global mean and variance. In train period, + True will track global mean and variance used for inference. When inference, track_running_stats must be + True. Default: True. + name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Shape: + - x: 4-D tensor with shape: (batch, num_features, height, weight). + - output: 4-D tensor with same shape as input x. + + Returns: + None + + **Note**: + Now track_running_stats is actucal always true. The next version will fix the problem . + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 1, 2, 3)).astype('float32') + x = paddle.to_tensor(x_data) + batch_norm = paddle.nn.BatchNorm2d(1) + batch_norm_out = batch_norm(x) + + print(batch_norm_out.numpy) + """ + + def _check_input_dim(self, input): + if len(input.shape) != 4: + raise ValueError('expected 4D input (got {}D input)'.format( + len(input.shape))) + + +class BatchNorm3d(_BatchNormBase): + """ + Applies Batch Normalization over a 5D input (a mini-batch of 3D inputswith additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . + + When track_running_stats = False, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are the statistics of one mini-batch. + Calculated as follows: + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ + \ mini-batch\ mean \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ + + When track_running_stats = True, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch. + They are global or running statistics (moving_mean and moving_variance). It usually got from the + pre-trained model. Calculated as follows: + + .. math:: + moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\ + moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\ + + The normalization function formula is as follows: + + .. math:: + + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + - :math:`\\epsilon` : add a smaller value to the variance to prevent division by zero + - :math:`\\gamma` : trainable proportional parameter + - :math:`\\beta` : trainable deviation parameter + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as weight_attr. If it is set to Fasle, the weight is not learnable. + If the Initializer of the weight_attr is not set, the parameter is initialized with Xavier. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of batch_norm. + If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. + If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. + data_format(str, optional): Specify the input data format, the data format can be "NCDHW". Default: NCDHW. + track_running_stats(bool, optional): Whether to use global mean and variance. In train period, + True will track global mean and variance used for inference. When inference, track_running_stats must be + True. Default: True. + name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Shape: + - x: 5-D tensor with shape: (batch, num_features, dims, height, weight). + - output: 5-D tensor with same shape as input x. + + Returns: + None + + **Note**: + Now track_running_stats is actucal always true. The next version will fix the problem . + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + np.random.seed(123) + x_data = np.random.random(size=(2, 1, 2, 2, 3)).astype('float32') + x = paddle.to_tensor(x_data) + batch_norm = paddle.nn.BatchNorm3d(1) + batch_norm_out = batch_norm(x) + + print(batch_norm_out.numpy) + """ + + def _check_input_dim(self, input): + if len(input.shape) != 5: + raise ValueError('expected 5D input (got {}D input)'.format( + len(input.shape))) + + class SyncBatchNorm(layers.Layer): """ This interface is used to construct a callable object of the ``SyncBatchNorm`` class.