diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 77436e9293d64..50baf6831dc8a 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -49,9 +49,7 @@ import paddle.utils.deprecated as deprecated from paddle import _C_ops, _legacy_C_ops -__all__ = [ - 'BatchNorm', -] +__all__ = [] class BatchNorm(layers.Layer): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/darknet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/darknet.py index 9199d0c2d96b2..783dfff262e8f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/darknet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/darknet.py @@ -14,9 +14,9 @@ import paddle import paddle.fluid as fluid -from paddle.fluid.dygraph.nn import BatchNorm from paddle.fluid.param_attr import ParamAttr from paddle.fluid.regularizer import L2Decay +from paddle.nn import BatchNorm class ConvBNLayer(fluid.dygraph.Layer): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py index 5121de5e6a03b..608baa74ec1e0 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py @@ -31,15 +31,16 @@ import numpy as np from PIL import Image, ImageOps +import paddle.fluid as fluid + # Use GPU:0 to elimate the influence of other tasks. os.environ["CUDA_VISIBLE_DEVICES"] = "1" import paddle -import paddle.fluid as fluid from paddle.fluid.dygraph import to_variable -from paddle.fluid.dygraph.nn import BatchNorm from paddle.jit import ProgramTranslator from paddle.jit.api import declarative +from paddle.nn import BatchNorm # Note: Set True to eliminate randomness. # 1. For one operation, cuDNN has several algorithms, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mobile_net.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mobile_net.py index d464c4b1d13a7..d5a4ae996d68b 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mobile_net.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_mobile_net.py @@ -23,12 +23,11 @@ import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX -from paddle.fluid.dygraph.nn import BatchNorm from paddle.fluid.initializer import MSRA from paddle.fluid.param_attr import ParamAttr from paddle.jit import ProgramTranslator from paddle.jit.api import declarative -from paddle.nn import Linear +from paddle.nn import BatchNorm, Linear # Note: Set True to eliminate randomness. # 1. 
For one operation, cuDNN has several algorithms, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py index f59dee857a823..cacf9b40d5cc0 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py @@ -24,8 +24,8 @@ import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX -from paddle.fluid.dygraph.nn import BatchNorm from paddle.jit import ProgramTranslator +from paddle.nn import BatchNorm SEED = 2020 IMAGENET1000 = 1281167 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py index 34ae0cd19a7ee..fe6437936498b 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py @@ -26,10 +26,9 @@ import paddle.fluid as fluid from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX -from paddle.fluid.dygraph.nn import BatchNorm from paddle.jit import ProgramTranslator from paddle.jit.api import declarative -from paddle.nn import Linear +from paddle.nn import BatchNorm, Linear SEED = 2020 np.random.seed(SEED) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tsm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tsm.py index b66a3ed21f47a..805e42a03a10e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tsm.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tsm.py @@ -24,10 +24,9 @@ import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import to_variable -from paddle.fluid.dygraph.nn import BatchNorm from paddle.jit import ProgramTranslator from paddle.jit.api import declarative -from paddle.nn import Linear +from paddle.nn import BatchNorm, Linear random.seed(0) np.random.seed(0) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_batch_norm_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_batch_norm_op_mlu.py index 29be16759e9c2..66876ddb79294 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_batch_norm_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_batch_norm_op_mlu.py @@ -18,7 +18,6 @@ import paddle import paddle.fluid.core as core from paddle.fluid.op import Operator -import paddle.fluid as fluid import sys sys.path.append('..') @@ -753,7 +752,7 @@ def test_errors(self): class TestDygraphBatchNormAPIError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): - batch_norm = fluid.dygraph.BatchNorm(10) + batch_norm = paddle.nn.BatchNorm(10) # the input of BatchNorm must be Variable. 
x1 = fluid.create_lod_tensor( np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace() @@ -776,7 +775,7 @@ def test_dygraph(self): def compute(x, is_test, trainable_statistics): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -799,7 +798,7 @@ def test_static(self): def compute(x_np, is_test, trainable_statistics): with program_guard(Program(), Program()): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -824,7 +823,7 @@ def test_reservespace(self): x = fluid.data(name='x', shape=x.shape, dtype=x.dtype) # Set this FLAG, the BatchNorm API will pass "reserve_space" argument into batch_norm op. os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1' - batch_norm = fluid.dygraph.BatchNorm(7, data_layout="NHWC") + batch_norm = paddle.nn.BatchNorm(7, data_layout="NHWC") hidden1 = batch_norm(x) os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '0' diff --git a/python/paddle/fluid/tests/unittests/mlu/test_batch_norm_op_mlu_v2.py b/python/paddle/fluid/tests/unittests/mlu/test_batch_norm_op_mlu_v2.py index 72e7ac89caf36..17672d668d38a 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_batch_norm_op_mlu_v2.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_batch_norm_op_mlu_v2.py @@ -17,7 +17,6 @@ import numpy as np import paddle.fluid.core as core from paddle.fluid.op import Operator -import paddle.fluid as fluid import sys sys.path.append("..") @@ -95,7 +94,7 @@ def test_dygraph(self): def compute_v1(x, is_test, trainable_statistics): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -111,7 +110,7 @@ def compute_v2(x): def compute_v3(x, is_test, trainable_statistics): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, param_attr=fluid.ParamAttr( @@ -153,7 +152,7 @@ def test_static(self): def compute_v1(x_np, is_test, trainable_statistics): with program_guard(Program(), Program()): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -260,7 +259,7 @@ def test_global_stats(self): for p in self.places: with fluid.dygraph.guard(p): x = paddle.randn([2, 6, 6, 4]) - net1 = paddle.fluid.dygraph.BatchNorm( + net1 = paddle.nn.BatchNorm( 6, param_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(1.0) diff --git a/python/paddle/fluid/tests/unittests/npu/test_batch_norm_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_batch_norm_op_npu.py index e39506eed7a9b..353fd250a5e1b 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_batch_norm_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_batch_norm_op_npu.py @@ -562,7 +562,7 @@ def test_dygraph(self): def compute(x, is_test, trainable_statistics): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -583,7 +583,7 @@ def test_static(self): def compute(x_np, is_test, trainable_statistics): with program_guard(Program(), Program()): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py 
b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 079628658addb..6802a8a9ea995 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -770,7 +770,7 @@ def test_errors(self): class TestDygraphBatchNormAPIError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): - batch_norm = fluid.dygraph.BatchNorm(10) + batch_norm = paddle.nn.BatchNorm(10) # the input of BatchNorm must be Variable. x1 = fluid.create_lod_tensor( np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace() @@ -793,7 +793,7 @@ def test_dygraph(self): def compute(x, is_test, trainable_statistics): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -816,7 +816,7 @@ def test_static(self): def compute(x_np, is_test, trainable_statistics): with program_guard(Program(), Program()): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -841,7 +841,7 @@ def test_reservespace(self): x = fluid.data(name='x', shape=x.shape, dtype=x.dtype) # Set this FLAG, the BatchNorm API will pass "reserve_space" argument into batch_norm op. os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1' - batch_norm = fluid.dygraph.BatchNorm(7, data_layout="NHWC") + batch_norm = paddle.nn.BatchNorm(7, data_layout="NHWC") hidden1 = batch_norm(x) os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '0' diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py index 1818d392057e6..74edcd61d343e 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py @@ -82,7 +82,7 @@ def error3d(): def test_large_batch(self): def compute_baseline(x): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm(shape[1]) + bn = paddle.nn.BatchNorm(shape[1]) x1 = paddle.to_tensor(x) x1.stop_gradient = False y = bn(x1) @@ -128,7 +128,7 @@ def test_eager_api(self): def compute_v1(x): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm(shape[1]) + bn = paddle.nn.BatchNorm(shape[1]) # bn = paddle.nn.BatchNorm2D(shape[1]) x1 = paddle.to_tensor(x) x1.stop_gradient = False @@ -162,7 +162,7 @@ def test_dygraph(self): def compute_v1(x, is_test, trainable_statistics): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -183,7 +183,7 @@ def compute_v2(x): def compute_v3(x, is_test, trainable_statistics): with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, param_attr=fluid.ParamAttr( @@ -225,7 +225,7 @@ def test_static(self): def compute_v1(x_np, is_test, trainable_statistics): with program_guard(Program(), Program()): - bn = fluid.dygraph.BatchNorm( + bn = paddle.nn.BatchNorm( shape[1], is_test=is_test, trainable_statistics=trainable_statistics, @@ -379,7 +379,7 @@ def test_global_stats(self): for p in self.places: with fluid.dygraph.guard(p): x = paddle.randn([2, 6, 6, 4]) - net1 = paddle.fluid.dygraph.BatchNorm( + net1 = paddle.nn.BatchNorm( 6, param_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(1.0) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py 
b/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py index 2003e685327b8..adc1b37770c7d 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py @@ -21,8 +21,7 @@ import paddle import paddle.fluid as fluid import paddle.fluid.framework as framework -from paddle.fluid.dygraph.nn import BatchNorm -from paddle.nn import Linear +from paddle.nn import BatchNorm, Linear class TestDygraphLoadStatic(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py index 12118beaffe3b..d64b48ee46080 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py @@ -21,9 +21,8 @@ import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid.dygraph.base import to_variable -from paddle.fluid.dygraph.nn import BatchNorm from paddle.fluid.framework import _test_eager_guard -from paddle.nn import Linear +from paddle.nn import BatchNorm, Linear class Config: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py index 16951a8743c4f..fcff2fc7268aa 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py @@ -20,10 +20,11 @@ import paddle import paddle.fluid as fluid -from paddle.fluid import BatchNorm, core +from paddle.fluid import core from paddle.fluid.dygraph.base import to_variable from paddle.fluid.framework import _test_eager_guard from paddle.fluid.layer_helper import LayerHelper +from paddle.nn import BatchNorm # NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1 diff --git a/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py b/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py index 6180d1c66494b..5908555ea7052 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py @@ -20,12 +20,9 @@ import paddle import paddle.fluid as fluid from paddle.fluid import core -from paddle.fluid.dygraph.nn import BatchNorm from paddle.fluid.framework import _test_eager_guard from paddle.fluid.layer_helper import LayerHelper - -if fluid.is_compiled_with_cuda(): - fluid.set_flags({'FLAGS_cudnn_deterministic': True}) +from paddle.nn import BatchNorm batch_size = 8 train_parameters = { @@ -120,7 +117,6 @@ def __init__(self, num_channels, reduction_ratio): initializer=paddle.nn.initializer.Constant(value=0.05) ), ) - self.act_2 = paddle.nn.Softmax() def forward(self, input): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py index 7c8f6d103467b..d3909193cd6ce 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py @@ -366,7 +366,7 @@ def test_global_stats(self): for p in self.places: with fluid.dygraph.guard(p): x = paddle.randn([2, 6, 6, 4]) - net1 = paddle.fluid.dygraph.BatchNorm( + net1 = paddle.nn.BatchNorm( 6, param_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(1.0) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_fused_resnet_basic_block_op_xpu.py 
b/python/paddle/fluid/tests/unittests/xpu/test_fused_resnet_basic_block_op_xpu.py index 68bf21abd1a06..3518083d75678 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_fused_resnet_basic_block_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_fused_resnet_basic_block_op_xpu.py @@ -113,7 +113,7 @@ def Base(self): bias_attr=None, data_format='NCHW', ) - self.bn1 = nn.BatchNorm( + self.bn1 = paddle.nn.BatchNorm( self.out_channels, act='relu', param_attr=bn1_weight, @@ -130,7 +130,7 @@ def Base(self): bias_attr=None, data_format='NCHW', ) - self.bn2 = nn.BatchNorm( + self.bn2 = paddle.nn.BatchNorm( self.out_channels, act=None, param_attr=bn2_weight, @@ -147,7 +147,7 @@ def Base(self): bias_attr=None, data_format='NCHW', ) - self.bn3 = nn.BatchNorm( + self.bn3 = paddle.nn.BatchNorm( self.out_channels, act=None, param_attr=bn3_weight, diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 1c7e64d794a65..f446970ee0ef6 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -37,9 +37,15 @@ from paddle.device import get_all_custom_device_type from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode +from ...fluid import dygraph_utils from ...fluid.data_feeder import check_variable_and_dtype -from ...fluid.dygraph import BatchNorm # noqa: F401 -from ...framework import ParamAttr, get_default_dtype, no_grad +from ...framework import ( + ParamAttr, + _global_flags, + _non_static_mode, + get_default_dtype, + no_grad, +) from .. import Layer from .. import functional as F from ..functional import batch_norm, instance_norm, layer_norm @@ -752,6 +758,312 @@ def extra_repr(self): return main_str + +class BatchNorm(Layer): + r""" + This interface is used to construct a callable object of the ``BatchNorm`` class. + For more details, refer to code examples. + It implements the function of the Batch Normalization Layer and can be used + as a normalizer function for conv2d and fully connected operations. + The data is normalized by the mean and variance of the channel based on the current batch data. + Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_ + for more details. + + When use_global_stats = False, the :math:`\mu_{\beta}` + and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch. + Calculated as follows: + + .. math:: + + \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad & + //\ mini-batch\ mean \\ + \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_{\beta})^2 \qquad & + //\ mini-batch\ variance \\ + + - :math:`x` : mini-batch data + - :math:`m` : the size of the mini-batch data + + When use_global_stats = True, the :math:`\mu_{\beta}` + and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch. + They are global or running statistics (moving_mean and moving_variance), usually obtained from a + pre-trained model. Calculated as follows: + + .. math:: + moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\ + moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\ + + The normalization function formula is as follows: + + ..
math:: + + \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\ + \sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\ + y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift + + + - :math:`\epsilon` : a small value added to the variance to prevent division by zero + - :math:`\gamma` : trainable scale parameter + - :math:`\beta` : trainable shift parameter + + Parameters: + num_channels(int): Indicates the number of channels of the input ``Tensor``. + act(str, optional): Activation to be applied to the output of batch normalization. Default: None. + is_test (bool, optional): A flag indicating whether it is in test phase or not. + This flag only takes effect in static graph mode. For dygraph mode, please use ``eval()``. + Default: False. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale` + of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with Xavier. Default: None. + bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm. + If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized to zero. Default: None. + dtype(str, optional): Indicates the data type of the input ``Tensor``, + which can be float32 or float64. Default: float32. + data_layout(str, optional): Specifies the input data format; the data format can be "NCHW" or "NHWC". Default: NCHW. + in_place(bool, optional): Whether to make the input and output of batch norm reuse memory. Default: False. + moving_mean_name(str, optional): The name of moving_mean which stores the global mean. Default: None. + moving_variance_name(str, optional): The name of the moving_variance which stores the global variance. Default: None. + do_model_average_for_mean_and_var(bool, optional): Whether the mean and variance should participate in model + averaging when model averaging is enabled. Default: True. + use_global_stats(bool, optional): Whether to use global mean and + variance. In inference or test mode, setting use_global_stats to True + or is_test to True is equivalent. + In train mode, when use_global_stats is set to True, the global mean + and variance are also used during training. Default: False. + trainable_statistics(bool, optional): Whether to calculate mean and variance in eval mode. In eval mode, when + trainable_statistics is set to True, mean and variance will be calculated from the current batch statistics. + Default: False. + + Returns: + None + + Examples: + ..
code-block:: python + + import paddle.fluid as fluid + import paddle.nn as nn + from paddle.fluid.dygraph.base import to_variable + import numpy as np + + + x = np.random.random(size=(3, 10, 3, 7)).astype('float32') + with fluid.dygraph.guard(): + x = to_variable(x) + batch_norm = nn.layer.norm.BatchNorm(10) + hidden1 = batch_norm(x) + """ + + def __init__( + self, + num_channels, + act=None, + is_test=False, + momentum=0.9, + epsilon=1e-05, + param_attr=None, + bias_attr=None, + dtype='float32', + data_layout='NCHW', + in_place=False, + moving_mean_name=None, + moving_variance_name=None, + do_model_average_for_mean_and_var=True, + use_global_stats=False, + trainable_statistics=False, + ): + super().__init__() + self._param_attr = param_attr + self._bias_attr = bias_attr + self._act = act + self._use_mkldnn = _global_flags()["FLAGS_use_mkldnn"] + + assert ( + bias_attr is not False + ), "bias_attr should not be False in batch_norm." + + if dtype == "float16": + self._dtype = "float32" + else: + self._dtype = dtype + + param_shape = [num_channels] + + # create parameter + self.weight = self.create_parameter( + attr=self._param_attr, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(1.0), + ) + self.weight.stop_gradient = ( + use_global_stats and self._param_attr.learning_rate == 0.0 + ) + + self.bias = self.create_parameter( + attr=self._bias_attr, + shape=param_shape, + dtype=self._dtype, + is_bias=True, + ) + self.bias.stop_gradient = ( + use_global_stats and self._param_attr.learning_rate == 0.0 + ) + + self._mean = self.create_parameter( + attr=ParamAttr( + name=moving_mean_name, + initializer=Constant(0.0), + trainable=False, + do_model_average=do_model_average_for_mean_and_var, + ), + shape=param_shape, + dtype=self._dtype, + ) + self._mean.stop_gradient = True + + self._variance = self.create_parameter( + attr=ParamAttr( + name=moving_variance_name, + initializer=Constant(1.0), + trainable=False, + do_model_average=do_model_average_for_mean_and_var, + ), + shape=param_shape, + dtype=self._dtype, + ) + self._variance.stop_gradient = True + + self._in_place = in_place + self._data_layout = data_layout + self._momentum = momentum + self._epsilon = epsilon + self._is_test = is_test + self._fuse_with_relu = False + self._use_global_stats = use_global_stats + self._trainable_statistics = trainable_statistics + + def forward(self, input): + # create output + # mean and mean_out share the same memory + mean_out = self._mean + # variance and variance out share the same memory + variance_out = self._variance + + if _non_static_mode(): + if in_dygraph_mode(): + batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm( + input, + self._mean, + self._variance, + self.weight, + self.bias, + not self.training, + self._momentum, + self._epsilon, + self._data_layout, + self._use_global_stats, + self._trainable_statistics, + ) + return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn + ) + + elif _in_legacy_dygraph(): + attrs = ( + "momentum", + self._momentum, + "epsilon", + self._epsilon, + "is_test", + not self.training, + "data_layout", + self._data_layout, + "use_mkldnn", + self._use_mkldnn, + "fuse_with_relu", + self._fuse_with_relu, + "use_global_stats", + self._use_global_stats, + 'trainable_statistics', + self._trainable_statistics, + ) + batch_norm_out, _, _, _, _, _ = _legacy_C_ops.batch_norm( + input, + self.weight, + self.bias, + self._mean, + self._variance, + None, + mean_out, + variance_out, + *attrs + ) + + 
return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn + ) + + check_variable_and_dtype( + input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm' + ) + + attrs = { + "momentum": self._momentum, + "epsilon": self._epsilon, + "is_test": self._is_test, + "data_layout": self._data_layout, + "use_mkldnn": False, + "fuse_with_relu": self._fuse_with_relu, + "use_global_stats": self._use_global_stats, + "trainable_statistics": self._trainable_statistics, + } + + inputs = { + "X": [input], + "Scale": [self.weight], + "Bias": [self.bias], + "Mean": [self._mean], + "Variance": [self._variance], + } + + saved_mean = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True + ) + saved_variance = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True + ) + reserve_space = self._helper.create_variable_for_type_inference( + dtype=self._helper.input_dtype(input), stop_gradient=True + ) + + batch_norm_out = ( + input + if self._in_place + else self._helper.create_variable_for_type_inference(self._dtype) + ) + + outputs = { + "Y": [batch_norm_out], + "MeanOut": [mean_out], + "VarianceOut": [variance_out], + "SavedMean": [saved_mean], + "SavedVariance": [saved_variance], + } + if reserve_space is not None: + outputs["ReserveSpace"] = [reserve_space] + + self._helper.append_op( + type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs + ) + + # Currently, we don't support inplace in dygraph mode + return self._helper.append_activation(batch_norm_out, self._act) + + class BatchNorm1D(_BatchNormBase): r""" Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputswith additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .
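As a quick reference for code that migrates off the removed fluid alias, the following minimal sketch (illustrative only, not part of the patch; it assumes a Paddle build that already includes this change, so that paddle.nn.BatchNorm resolves to the class added above) uses the relocated layer and checks its training-mode output against the normalization formula from the docstring with NumPy. The printed tolerance is an expectation, not a documented guarantee.

    # Illustrative sketch: use the relocated BatchNorm and verify the
    # training-mode output against the docstring's normalization formula.
    import numpy as np

    import paddle
    from paddle.nn import BatchNorm  # previously: from paddle.fluid.dygraph.nn import BatchNorm

    np.random.seed(0)
    paddle.seed(0)

    x_np = np.random.random(size=(3, 10, 3, 7)).astype('float32')  # NCHW, 10 channels

    bn = BatchNorm(10)  # weight defaults to 1.0, bias to 0.0
    bn.train()
    y = bn(paddle.to_tensor(x_np)).numpy()

    # Reference computation: y_hat = gamma * (x - mu) / sqrt(sigma^2 + eps) + beta
    mu = x_np.mean(axis=(0, 2, 3), keepdims=True)
    var = x_np.var(axis=(0, 2, 3), keepdims=True)
    eps = 1e-5  # default epsilon of BatchNorm
    gamma = bn.weight.numpy().reshape(1, -1, 1, 1)
    beta = bn.bias.numpy().reshape(1, -1, 1, 1)
    y_ref = gamma * (x_np - mu) / np.sqrt(var + eps) + beta

    print("max abs diff:", np.abs(y - y_ref).max())  # expected to be tiny (~1e-6)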