diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index e061fd0ab93ab..63b3b0f1a3408 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -781,8 +781,8 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
   }
 };
 
-// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
-// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
+// leaky_relu Grad: dx=dy if x>=0 else alpha * dy
+// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx
 template <typename T>
 class LeakyReluDoubleGradMaker
     : public ::paddle::framework::SingleGradOpMaker<T> {
@@ -792,8 +792,8 @@ class LeakyReluDoubleGradMaker
  protected:
   void Apply(GradOpPtr<T> op) const override {
     op->SetType("leaky_relu_grad_grad");
-    // input1: Out
-    op->SetInput("Out", this->Input("Out"));
+    // input1: X
+    op->SetInput("X", this->Input("X"));
     // X@GRAD@GRAD: ddx
     op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
     op->SetAttrMap(this->Attrs());
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 018f515d81e93..b411f0f21da21 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -1084,7 +1084,11 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
 
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
+    if (alpha < 1.f) {
+      out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
+    } else {
+      out.device(d) = x.cwiseMin(static_cast<T>(alpha) * x);
+    }
   }
 };
 
@@ -1098,12 +1102,12 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp1 =
-        static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
-    auto temp2 = (out > static_cast<T>(0)).template cast<T>();
+        static_cast<T>(alpha) * (x < static_cast<T>(0)).template cast<T>();
+    auto temp2 = (x >= static_cast<T>(0)).template cast<T>();
     dx.device(d) = dout * (temp1 + temp2).template cast<T>();
   }
 
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
 template <typename T>
@@ -1451,18 +1455,18 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
       auto* d = dev.eigen_device();
       auto ddx = framework::EigenVector<T>::Flatten(
           GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
-      auto out = framework::EigenVector<T>::Flatten(
-          GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad"));
+      auto x = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad"));
       auto ddout = framework::EigenVector<T>::Flatten(
           GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad"));
-      ddout.device(*d) = ddx *
-                         ((out > static_cast<T>(0)).template cast<T>() +
-                          static_cast<T>(alpha) *
-                              (out <= static_cast<T>(0)).template cast<T>())
-                             .template cast<T>();
+      ddout.device(*d) =
+          ddx *
+          ((x > static_cast<T>(0)).template cast<T>() +
+           static_cast<T>(alpha) * (x <= static_cast<T>(0)).template cast<T>())
+              .template cast<T>();
     }
   }
 
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
 template <typename T>
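Why the kernels now read `X` instead of `Out`: once `alpha` may be negative (or greater than one), the sign of the forward output no longer tracks the sign of the input, so a gradient mask built from `Out` selects the wrong branch. A minimal numpy sketch of the issue (illustrative only; `leaky_relu_ref`, `grad_from_out` and `grad_from_x` are made-up names, not Paddle APIs):

```python
import numpy as np

def leaky_relu_ref(x, alpha):
    # forward: x for x >= 0, alpha * x otherwise (what the fixed functor computes)
    return np.where(x >= 0, x, alpha * x)

def grad_from_out(dout, out, alpha):
    # old rule: mask derived from the output
    return np.where(out > 0, dout, alpha * dout)

def grad_from_x(dout, x, alpha):
    # new rule: mask derived from the input
    return np.where(x >= 0, dout, alpha * dout)

x = np.array([-1.0, 0.5])
dout = np.ones_like(x)
alpha = -2.0                             # negative slope, as exercised by the new unit tests
out = leaky_relu_ref(x, alpha)           # [2. , 0.5] -- positive even though x[0] < 0
print(grad_from_out(dout, out, alpha))   # [ 1.,  1.]  wrong for x[0]
print(grad_from_x(dout, x, alpha))       # [-2.,  1.]  alpha * dout, as expected
```

The forward change follows the same reasoning: `max(x, alpha * x)` equals leaky_relu only when `alpha <= 1`, hence the new `cwiseMin` branch for `alpha >= 1`.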
diff --git a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
index f416aa6e00f5a..cc2fe4cdbdb8f 100644
--- a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
+++ b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
@@ -41,12 +41,12 @@ static void InitRandom(framework::Tensor *tensor,
 
 template <typename T>
 struct LeakyReluGradGradEachElementFunctor {
-  LeakyReluGradGradEachElementFunctor(const T *ddx, const T *out, T alpha,
+  LeakyReluGradGradEachElementFunctor(const T *ddx, const T *x, T alpha,
                                       T *ddout)
-      : ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {}
+      : ddx_(ddx), x_(x), alpha_(alpha), ddout_(ddout) {}
 
   HOSTDEVICE void operator()(int idx) {
-    if (out_[idx] > 0) {
+    if (x_[idx] >= 0) {
       ddout_[idx] = ddx_[idx];
     } else {
       ddout_[idx] = ddx_[idx] * alpha_;
@@ -54,7 +54,7 @@ struct LeakyReluGradGradEachElementFunctor {
   }
 
   const T *ddx_;
-  const T *out_;
+  const T *x_;
   T alpha_;
   T *ddout_;
 };
@@ -66,13 +66,13 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
   LeakyReluGradGradFunctor<T> functor;
   functor.alpha = alpha;
   auto &dev_ctx = *platform::DeviceContextPool::Instance().Get(place);
-  framework::Tensor *x = nullptr;
+  framework::Tensor *out = nullptr;
   framework::Tensor *dout = nullptr;
   framework::Tensor *dx = nullptr;
 
-  framework::Tensor out;
-  out.Resize(dim);
-  InitRandom<T>(&out, place);
+  framework::Tensor x;
+  x.Resize(dim);
+  InitRandom<T>(&x, place);
 
   framework::Tensor ddx;
   ddx.Resize(dim);
@@ -85,22 +85,22 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
   framework::Tensor ddout_actual;
   ddout_actual.mutable_data<T>(dim, place);
   LeakyReluGradGradEachElementFunctor<T> actual_functor(
-      ddx.data<T>(), out.data<T>(), static_cast<T>(alpha),
+      ddx.data<T>(), x.data<T>(), static_cast<T>(alpha),
       ddout_actual.data<T>());
 
-  int64_t limit = out.numel();
+  int64_t limit = x.numel();
 
 #ifdef __NVCC__
   if (platform::is_gpu_place(place)) {
     auto &cuda_dev_ctx = dynamic_cast<platform::CUDADeviceContext &>(dev_ctx);
-    functor(cuda_dev_ctx, x, &out, &ddx, &ddout, dout, dx);
+    functor(cuda_dev_ctx, &x, out, &ddx, &ddout, dout, dx);
     platform::ForRange<platform::CUDADeviceContext> for_range(cuda_dev_ctx, limit);
     for_range(actual_functor);
   } else {
 #endif
     auto &cpu_dev_ctx = dynamic_cast<platform::CPUDeviceContext &>(dev_ctx);
-    functor(cpu_dev_ctx, x, &out, &ddx, &ddout, dout, dx);
+    functor(cpu_dev_ctx, &x, out, &ddx, &ddout, dout, dx);
     platform::ForRange<platform::CPUDeviceContext> for_range(cpu_dev_ctx, limit);
     for_range(actual_functor);
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 51db78d2ace46..1c7b27d93c498 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -9715,13 +9715,10 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     return out
 
 
+@deprecated(since="2.0.0", update_to="paddle.nn.functional.leaky_relu")
 @templatedoc()
 def leaky_relu(x, alpha=0.02, name=None):
     """
-    :alias_main: paddle.nn.functional.leaky_relu
-    :alias: paddle.nn.functional.leaky_relu,paddle.nn.functional.activation.leaky_relu
-    :old_api: paddle.fluid.layers.leaky_relu
-
     ${comment}
     Args:
         x(${x_type}): ${x_comment}
@@ -9750,19 +9747,7 @@ def leaky_relu(x, alpha=0.02, name=None):
             res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res])
             print(res_val)  # [[-0.1, 2], [3, -0.4]]
     """
-    if in_dygraph_mode():
-        return core.ops.leaky_relu(x, 'alpha', alpha)
-
-    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
-                             'leaky_relu')
-
-    inputs = {'X': [x]}
-    attrs = {'alpha': alpha}
-    helper = LayerHelper('leaky_relu', **locals())
-    out = helper.create_variable_for_type_inference(dtype=x.dtype)
-    helper.append_op(
-        type='leaky_relu', inputs=inputs, outputs={'Out': out}, attrs=attrs)
-    return out
+    return paddle.nn.functional.leaky_relu(x, alpha, name)
 
 
 def soft_relu(x, threshold=40.0, name=None):
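With this change, `fluid.layers.leaky_relu` becomes a deprecated thin wrapper that simply forwards to the new functional API (note the argument rename: `alpha` in fluid, `negative_slope` in `paddle.nn.functional`). A small dygraph sketch of the equivalence, assuming a Paddle 2.0-style environment:

```python
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.to_tensor(np.array([-1.0, 2.0], dtype='float32'))

y_old = fluid.layers.leaky_relu(x, alpha=0.1)   # deprecated path, forwards to F.leaky_relu
y_new = F.leaky_relu(x, negative_slope=0.1)     # new API

print(np.allclose(y_old.numpy(), y_new.numpy()))  # True
```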
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index e33d4aa6c3631..174fff4acde6f 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -903,18 +903,30 @@ def test_errors(self):
             F.relu(x_fp16)
 
 
+def ref_leaky_relu(x, alpha=0.01):
+    out = np.copy(x)
+    out[out < 0] *= alpha
+    return out
+
+
 class TestLeakyRelu(TestActivation):
+    def get_alpha(self):
+        return 0.02
+
     def setUp(self):
         self.op_type = "leaky_relu"
         self.init_dtype()
+        alpha = self.get_alpha()
 
+        np.random.seed(10)
         x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
         # The same reason with TestAbs
-        x[np.abs(x) < 0.005] = 0.02
-        out = np.maximum(x, 0.02 * x)
+        x[np.abs(x) < 0.005] = 0.05
+        out = ref_leaky_relu(x, alpha)
 
-        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.inputs = {'X': x}
         self.outputs = {'Out': out}
+        self.attrs = {'alpha': alpha}
 
     def test_check_grad(self):
         if self.dtype == np.float16:
@@ -922,18 +934,78 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
-class TestLeakyReluOpError(unittest.TestCase):
+class TestLeakyReluAlpha1(TestLeakyRelu):
+    def get_alpha(self):
+        return 2
+
+
+class TestLeakyReluAlpha2(TestLeakyRelu):
+    def get_alpha(self):
+        return -0.01
+
+
+class TestLeakyReluAlpha3(TestLeakyRelu):
+    def get_alpha(self):
+        return -2.0
+
+
+class TestLeakyReluAPI(unittest.TestCase):
+    # test paddle.nn.LeakyReLU, paddle.nn.functional.leaky_relu,
+    # fluid.layers.leaky_relu
+    def setUp(self):
+        self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
+        self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
+            else paddle.CPUPlace()
+
+    def test_static_api(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.data('X', [10, 12])
+            out1 = F.leaky_relu(x)
+            m = paddle.nn.LeakyReLU()
+            out2 = m(x)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
+        out_ref = ref_leaky_relu(self.x_np)
+        for r in res:
+            self.assertEqual(np.allclose(out_ref, r), True)
+
+    def test_dygraph_api(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_variable(self.x_np)
+        out1 = F.leaky_relu(x)
+        m = paddle.nn.LeakyReLU()
+        out2 = m(x)
+        out_ref = ref_leaky_relu(self.x_np)
+        for r in [out1, out2]:
+            self.assertEqual(np.allclose(out_ref, r.numpy()), True)
+
+        out1 = F.leaky_relu(x, 0.6)
+        m = paddle.nn.LeakyReLU(0.6)
+        out2 = m(x)
+        out_ref = ref_leaky_relu(self.x_np, 0.6)
+        for r in [out1, out2]:
+            self.assertEqual(np.allclose(out_ref, r.numpy()), True)
+        paddle.enable_static()
+
+    def test_fluid_api(self):
+        with fluid.program_guard(fluid.Program()):
+            x = fluid.data('X', [10, 12])
+            out = fluid.layers.leaky_relu(x, 0.01)
+            exe = fluid.Executor(self.place)
+            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
+        out_ref = ref_leaky_relu(self.x_np)
+        self.assertEqual(np.allclose(out_ref, res[0]), True)
+
     def test_errors(self):
-        with program_guard(Program()):
+        with paddle.static.program_guard(paddle.static.Program()):
             # The input type must be Variable.
-            self.assertRaises(TypeError, fluid.layers.leaky_relu, 1)
+            self.assertRaises(TypeError, F.leaky_relu, 1)
             # The input dtype must be float16, float32, float64.
-            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
-            self.assertRaises(TypeError, fluid.layers.leaky_relu, x_int32)
-            # support the input dtype is float32
-            x_fp16 = fluid.layers.data(
-                name='x_fp16', shape=[12, 10], dtype='float32')
-            fluid.layers.leaky_relu(x_fp16)
+            x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
+            self.assertRaises(TypeError, F.leaky_relu, x_int32)
+            # support the input dtype is float16
+            x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
+            F.leaky_relu(x_fp16)
 
 
 def gelu(x, approximate):
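The new `TestLeakyReluAlpha*` variants cover slopes greater than one and negative slopes, exactly the cases the old `Out`-based kernels mishandled, and `check_grad` compares the op's analytic gradient against numeric differentiation. Roughly what that comparison amounts to, written as plain numpy (the helper names here are illustrative, not part of the OpTest machinery):

```python
import numpy as np

def ref_leaky_relu(x, alpha=0.01):
    out = np.copy(x)
    out[out < 0] *= alpha
    return out

def analytic_grad(x, alpha):
    # what the corrected kernel computes: 1 for x >= 0, alpha otherwise
    return np.where(x >= 0, 1.0, alpha)

def numeric_grad(x, alpha, eps=1e-6):
    # central finite differences; valid elementwise because the op is elementwise
    return (ref_leaky_relu(x + eps, alpha) - ref_leaky_relu(x - eps, alpha)) / (2 * eps)

np.random.seed(10)
x = np.random.uniform(-1, 1, [11, 17]).astype('float64')
x[np.abs(x) < 0.005] = 0.05   # keep samples away from the kink, as the test does
for alpha in (0.02, 2.0, -0.01, -2.0):
    assert np.allclose(analytic_grad(x, alpha), numeric_grad(x, alpha))
```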
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 91186b2e95ae0..4ce7bd693f3de 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -316,21 +316,6 @@ def test_relu(self):
 
         self.assertTrue(np.allclose(static_ret, dy_ret_value))
 
-    def test_leakyrelu(self):
-        inputs = np.random.uniform(-1, 1, (10, 10)).astype('float32')
-        with self.static_graph():
-            t = layers.data(name='t', shape=[10, 10], dtype='float32')
-            ret = layers.leaky_relu(t, alpha=0.01)
-            static_ret = self.get_static_graph_result(
-                feed={'t': inputs}, fetch_list=[ret])[0]
-
-        with self.dynamic_graph():
-            lrelu = paddle.nn.LeakyReLU(alpha=0.01)
-            dy_ret = lrelu(base.to_variable(inputs))
-            dy_ret_value = dy_ret.numpy()
-
-        self.assertTrue(np.allclose(static_ret, dy_ret_value))
-
     def test_pad2d(self):
         with self.static_graph():
             t = layers.data(name='t', shape=[-1, 3, 5, 5], dtype='float32')
@@ -2678,13 +2663,6 @@ def make_brelu(self):
             out = layers.brelu(input, t_min=1.0, t_max=20.0, name='brelu')
             return (out)
 
-    def make_leaky_relu(self):
-        with program_guard(fluid.default_main_program(),
-                           fluid.default_startup_program()):
-            input = self._get_data(name="input", shape=[16], dtype="float32")
-            out = layers.leaky_relu(input, alpha=0.1, name='leaky_relu')
-            return (out)
-
     def make_soft_relu(self):
         with program_guard(fluid.default_main_program(),
                            fluid.default_startup_program()):
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 619492a1b2bb1..44e322c6d4b3d 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -17,7 +17,6 @@
 from ...fluid.layers import erf  #DEFINE_ALIAS
 from ...fluid.layers import hard_sigmoid  #DEFINE_ALIAS
 from ...fluid.layers import hard_swish  #DEFINE_ALIAS
-from ...fluid.layers import leaky_relu  #DEFINE_ALIAS
 from ...fluid.layers import maxout  #DEFINE_ALIAS
 from ...fluid.layers import soft_relu  #DEFINE_ALIAS
 from ...fluid.layers import swish  #DEFINE_ALIAS
@@ -386,6 +385,57 @@ def hsigmoid(input,
     return out
 
 
+def leaky_relu(x, negative_slope=0.01, name=None):
+    """
+    leaky_relu activation
+
+    .. math::
+        leaky\_relu(x)=
+            \left\{
+                \begin{aligned}
+                    &x, & & if \ x >= 0 \\
+                    &negative\_slope * x, & & otherwise \\
+                \end{aligned}
+            \right. \\
+
+    Args:
+        x (Tensor): The input Tensor with data type float32, float64.
+        negative_slope (float, optional): Slope of the activation function at
+            :math:`x < 0` . Default is 0.01.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn.functional as F
+            import numpy as np
+
+            paddle.disable_static()
+
+            x = paddle.to_tensor(np.array([-2, 0, 1], 'float32'))
+            out = F.leaky_relu(x)  # [-0.02, 0., 1.]
+
+    """
+    if in_dygraph_mode():
+        return core.ops.leaky_relu(x, 'alpha', negative_slope)
+
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             'leaky_relu')
+    helper = LayerHelper('leaky_relu', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='leaky_relu',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'alpha': negative_slope})
+    return out
+
+
 def prelu(x, weight, name=None):
     """
     prelu activation.
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index bfa2e2f25d873..b9cc13fa85bd1 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -558,11 +558,17 @@ class LeakyReLU(layers.Layer):
 
     .. math:
 
-        out = max(x, alpha * x)
+        LeakyReLU(x)=
+            \left\{
+                \begin{aligned}
+                    &x, & & if \ x >= 0 \\
+                    &negative\_slope * x, & & otherwise \\
+                \end{aligned}
+            \right. \\
 
     Parameters:
-        alpha (float, optional): Slope of the activation function at :math:`x < 0` .
-            Default: 0.01.
+        negative_slope (float, optional): Slope of the activation function at
+            :math:`x < 0` . Default is 0.01.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.
 
@@ -573,23 +579,23 @@ class LeakyReLU(layers.Layer):
     Examples:
         .. code-block:: python
 
-        import paddle
-        import numpy as np
+            import paddle
+            import numpy as np
 
-        paddle.disable_static()
+            paddle.disable_static()
 
-        lrelu = paddle.nn.LeakyReLU()
-        x = paddle.to_tensor(np.array([-2, 0, 1], 'float32'))
-        out = lrelu(x)  # [-0.02, 0., 1.]
+            m = paddle.nn.LeakyReLU()
+            x = paddle.to_tensor(np.array([-2, 0, 1], 'float32'))
+            out = m(x)  # [-0.02, 0., 1.]
     """
 
-    def __init__(self, alpha=1e-2, name=None):
+    def __init__(self, negative_slope=0.01, name=None):
         super(LeakyReLU, self).__init__()
-        self._alpha = alpha
+        self._negative_slope = negative_slope
         self._name = name
 
     def forward(self, x):
-        return F.leaky_relu(x, self._alpha, self._name)
+        return F.leaky_relu(x, self._negative_slope, self._name)
 
 
 class Sigmoid(layers.Layer):
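For reference, the renamed parameter is used the same way through the functional and the Layer interface; a short dygraph sketch consistent with the docstrings above (values follow the documented example):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.to_tensor(np.array([-2.0, 0.0, 1.0], dtype='float32'))

out1 = F.leaky_relu(x)                       # default negative_slope=0.01 -> [-0.02, 0., 1.]
out2 = F.leaky_relu(x, negative_slope=0.6)   # [-1.2, 0., 1.]

m = paddle.nn.LeakyReLU(negative_slope=0.6)
out3 = m(x)                                  # same result as out2
print(out1.numpy(), out2.numpy(), out3.numpy())
```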