From a6daa82191beadedc618715c57e8d7b5dd437bc1 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Wed, 12 Aug 2020 15:24:14 +0800 Subject: [PATCH 01/10] leaky_relu and LeakyReLU: alpha->negative_slope --- paddle/fluid/operators/activation_op.cc | 8 +- paddle/fluid/operators/activation_op.h | 28 ++--- python/paddle/fluid/layers/nn.py | 19 +--- .../tests/unittests/test_activation_op.py | 101 +++++++++++++++--- python/paddle/nn/functional/activation.py | 53 ++++++++- python/paddle/nn/layer/activation.py | 18 ++-- 6 files changed, 172 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 7ea78879e1e08..c8778a9c96d19 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -753,8 +753,8 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker { } }; -// leaky_relu Grad: dx=dy if y>=0 else alpha * dy -// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx +// leaky_relu Grad: dx=dy if x>=0 else alpha * dy +// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx template class LeakyReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker { @@ -764,8 +764,8 @@ class LeakyReluDoubleGradMaker protected: void Apply(GradOpPtr op) const override { op->SetType("leaky_relu_grad_grad"); - // input1: Out - op->SetInput("Out", this->Input("Out")); + // input1: X + op->SetInput("X", this->Input("X")); // X@GRAD@GRAD: ddx op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); op->SetAttrMap(this->Attrs()); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 3aac7ae8a5e8a..ab057fba009b7 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -1070,7 +1070,11 @@ struct LeakyReluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.cwiseMax(static_cast(alpha) * x); + if 
(alpha < 1.f) { + out.device(d) = x.cwiseMax(static_cast(alpha) * x); + } else { + out.device(d) = x.cwiseMin(static_cast(alpha) * x); + } } }; @@ -1084,12 +1088,12 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor { typename dX> void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = - static_cast(alpha) * (out <= static_cast(0)).template cast(); - auto temp2 = (out > static_cast(0)).template cast(); + static_cast(alpha) * (x < static_cast(0)).template cast(); + auto temp2 = (x >= static_cast(0)).template cast(); dx.device(d) = dout * (temp1 + temp2).template cast(); } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -1437,18 +1441,18 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor { auto* d = dev.eigen_device(); auto ddx = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad")); - auto out = framework::EigenVector::Flatten( - GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad")); + auto x = framework::EigenVector::Flatten( + GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad")); auto ddout = framework::EigenVector::Flatten( GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad")); - ddout.device(*d) = ddx * - ((out > static_cast(0)).template cast() + - static_cast(alpha) * - (out <= static_cast(0)).template cast()) - .template cast(); + ddout.device(*d) = + ddx * + ((x > static_cast(0)).template cast() + + static_cast(alpha) * (x <= static_cast(0)).template cast()) + .template cast(); } } - static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d23f20e1b3d4b..edd155147e3f7 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9735,13 +9735,10 @@ def brelu(x, t_min=0.0, 
t_max=24.0, name=None): return out +@deprecated(since="2.0.0", update_to="paddle.nn.functional.leaky_relu") @templatedoc() def leaky_relu(x, alpha=0.02, name=None): """ - :alias_main: paddle.nn.functional.leaky_relu - :alias: paddle.nn.functional.leaky_relu,paddle.nn.functional.activation.leaky_relu - :old_api: paddle.fluid.layers.leaky_relu - ${comment} Args: x(${x_type}): ${x_comment} @@ -9770,19 +9767,7 @@ def leaky_relu(x, alpha=0.02, name=None): res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res]) print(res_val) # [[-0.1, 2], [3, -0.4]] """ - if in_dygraph_mode(): - return core.ops.leaky_relu(x, 'alpha', alpha) - - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], - 'leaky_relu') - - inputs = {'X': [x]} - attrs = {'alpha': alpha} - helper = LayerHelper('leaky_relu', **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type='leaky_relu', inputs=inputs, outputs={'Out': out}, attrs=attrs) - return out + return paddle.nn.functional.leaky_relu(x, alpha, name) def soft_relu(x, threshold=40.0, name=None): diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 124767a3364b0..3ebf71d7df6ad 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -22,7 +22,7 @@ import paddle import paddle.fluid as fluid import paddle.nn as nn -import paddle.nn.functional as functional +import paddle.nn.functional as F from paddle.fluid import compiler, Program, program_guard @@ -617,18 +617,29 @@ def test_errors(self): fluid.layers.relu(x_fp16) +def ref_leaky_relu(x, alpha=0.01): + out = np.copy(x) + out[out < 0] *= alpha + return out + + class TestLeakyRelu(TestActivation): + def get_alpha(self): + return 0.01 + def setUp(self): self.op_type = "leaky_relu" self.init_dtype() + alpha = self.get_alpha() x = np.random.uniform(-1, 1, [11, 
17]).astype(self.dtype) # The same reason with TestAbs - x[np.abs(x) < 0.005] = 0.02 - out = np.maximum(x, 0.02 * x) + x[np.abs(x) < 0.005] = 0.05 + out = ref_leaky_relu(x, alpha) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.inputs = {'X': x} self.outputs = {'Out': out} + self.attrs = {'alpha': alpha} def test_check_grad(self): if self.dtype == np.float16: @@ -636,18 +647,78 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') -class TestLeakyReluOpError(unittest.TestCase): +class TestLeakyReluAlpha1(TestLeakyRelu): + def get_alpha(self): + return 2 + + +class TestLeakyReluAlpha2(TestLeakyRelu): + def get_alpha(self): + return -0.01 + + +class TestLeakyReluAlpha3(TestLeakyRelu): + def get_alpha(self): + return -2.0 + + +class TestLeakyReluAPI(unittest.TestCase): + # test paddle.nn.LeakyReLU, paddle.nn.functional.leaky_relu, + # fluid.layers.leaky_relu + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32') + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [10, 12]) + out1 = F.leaky_relu(x) + m = paddle.nn.LeakyReLU() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_leaky_relu(self.x_np) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_variable(self.x_np) + out1 = F.leaky_relu(x) + m = paddle.nn.LeakyReLU() + out2 = m(x) + out_ref = ref_leaky_relu(self.x_np) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + out1 = F.leaky_relu(x, 0.6) + m = paddle.nn.LeakyReLU(0.6) + out2 = m(x) + out_ref = ref_leaky_relu(self.x_np, 0.6) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def 
test_fluid_api(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', [10, 12]) + out = fluid.layers.leaky_relu(x, 0.01) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_leaky_relu(self.x_np) + self.assertEqual(np.allclose(out_ref, res[0]), True) + def test_errors(self): - with program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, fluid.layers.leaky_relu, 1) + self.assertRaises(TypeError, F.leaky_relu, 1) # The input dtype must be float16, float32, float64. - x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, fluid.layers.leaky_relu, x_int32) - # support the input dtype is float32 - x_fp16 = fluid.layers.data( - name='x_fp16', shape=[12, 10], dtype='float32') - fluid.layers.leaky_relu(x_fp16) + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.leaky_relu, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.leaky_relu(x_fp16) def gelu(x, approximate): @@ -1435,7 +1506,7 @@ def test_check_api(self): main_program = Program() with fluid.program_guard(main_program): x = fluid.data(name='x', shape=self.x_shape) - y = functional.relu(x) + y = F.relu(x) exe = fluid.Executor(fluid.CPUPlace()) out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], self.y)) @@ -1501,7 +1572,7 @@ def test_check_api(self): main_program = Program() with fluid.program_guard(main_program): x = fluid.data(name='x', shape=self.x_shape) - y = functional.sigmoid(x) + y = F.sigmoid(x) exe = fluid.Executor(fluid.CPUPlace()) out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], self.y)) diff --git a/python/paddle/nn/functional/activation.py 
b/python/paddle/nn/functional/activation.py index f524d74f408c0..c6e0a75b77f43 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -20,7 +20,6 @@ from ...fluid.layers import hard_shrink #DEFINE_ALIAS from ...fluid.layers import hard_sigmoid #DEFINE_ALIAS from ...fluid.layers import hard_swish #DEFINE_ALIAS -from ...fluid.layers import leaky_relu #DEFINE_ALIAS from ...fluid.layers import logsigmoid #DEFINE_ALIAS from ...fluid.layers import maxout #DEFINE_ALIAS from ...fluid.layers import relu6 #DEFINE_ALIAS @@ -192,6 +191,58 @@ def hsigmoid(input, return out +def leaky_relu(x, negative_slope=0.01, name=None): + """ + leaky_relu activation + + .. math: + + leaky_relu(x)= + \left\{ + \begin{aligned} + &x, & & if \ x >= 0 \\ + &negative_slope * x, & & otherwise \\ + \end{aligned} + \right. \\ + + Args: + x (Tensor): The input Tensor with data type float32, float64. + negative_slope (float, optional): Slope of the activation function at + :math:`x < 0` . Default is 0.01. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_variable(np.array([-2, 0, 1])) + out = F.leaky_relu(x) # [-0.02, 0., 1.] 
+ """ + if in_dygraph_mode(): + return core.ops.leaky_relu(x, 'alpha', negative_slope) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'leaky_relu') + helper = LayerHelper('leaky_relu', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='leaky_relu', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'alpha': negative_slope}) + return out + + def relu(input, inplace=False, name=None): """ :alias_main: paddle.nn.functional.relu diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index d13f36a31854a..fa7a889593dc9 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -214,11 +214,17 @@ class LeakyReLU(layers.Layer): .. math: - out = max(x, alpha * x) + LeakyReLU(x)= + \left\{ + \begin{aligned} + &x, & & if \ x >= 0 \\ + &negative_slope * x, & & otherwise \\ + \end{aligned} + \right. \\ Parameters: - alpha (float, optional): Slope of the activation function at :math:`x < 0` . - Default: 0.01. + negative_slope (float, optional): Slope of the activation function at + :math:`x < 0` . Default is 0.01. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -239,13 +245,13 @@ class LeakyReLU(layers.Layer): out = lrelu(x) # [-0.02, 0., 1.] 
""" - def __init__(self, alpha=1e-2, name=None): + def __init__(self, negative_slope=0.01, name=None): super(LeakyReLU, self).__init__() - self._alpha = alpha + self._negative_slope = negative_slope self._name = name def forward(self, x): - return functional.leaky_relu(x, self._alpha, self._name) + return functional.leaky_relu(x, self._negative_slope, self._name) class Sigmoid(layers.Layer): From 71d04b47a9c4e7b5993edf77b737e89ab6b348d8 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Thu, 13 Aug 2020 14:12:14 +0800 Subject: [PATCH 02/10] fix --- .../test_leaky_relu_grad_grad_functor.h | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h index f416aa6e00f5a..182e56c25205c 100644 --- a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h +++ b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h @@ -41,12 +41,12 @@ static void InitRandom(framework::Tensor *tensor, template struct LeakyReluGradGradEachElementFunctor { - LeakyReluGradGradEachElementFunctor(const T *ddx, const T *out, T alpha, + LeakyReluGradGradEachElementFunctor(const T *ddx, const T *x, T alpha, T *ddout) - : ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {} + : ddx_(ddx), x_(x), alpha_(alpha), ddout_(ddout) {} HOSTDEVICE void operator()(int idx) { - if (out_[idx] > 0) { + if (x_[idx] >= 0) { ddout_[idx] = ddx_[idx]; } else { ddout_[idx] = ddx_[idx] * alpha_; @@ -54,7 +54,7 @@ struct LeakyReluGradGradEachElementFunctor { } const T *ddx_; - const T *out_; + const T *x_; T alpha_; T *ddout_; }; @@ -66,13 +66,13 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim, LeakyReluGradGradFunctor functor; functor.alpha = alpha; auto &dev_ctx = *platform::DeviceContextPool::Instance().Get(place); - framework::Tensor *x = nullptr; + framework::Tensor *out = nullptr; framework::Tensor *dout = nullptr; framework::Tensor *dx = 
nullptr; - framework::Tensor out; - out.Resize(dim); - InitRandom(&out, place); + framework::Tensor x; + x.Resize(dim); + InitRandom(&x, place); framework::Tensor ddx; ddx.Resize(dim); @@ -85,7 +85,7 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim, framework::Tensor ddout_actual; ddout_actual.mutable_data(dim, place); LeakyReluGradGradEachElementFunctor actual_functor( - ddx.data(), out.data(), static_cast(alpha), + ddx.data(), x.data(), static_cast(alpha), ddout_actual.data()); int64_t limit = out.numel(); @@ -93,14 +93,14 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim, #ifdef __NVCC__ if (platform::is_gpu_place(place)) { auto &cuda_dev_ctx = dynamic_cast(dev_ctx); - functor(cuda_dev_ctx, x, &out, &ddx, &ddout, dout, dx); + functor(cuda_dev_ctx, &x, out, &ddx, &ddout, dout, dx); platform::ForRange for_range(cuda_dev_ctx, limit); for_range(actual_functor); } else { #endif auto &cpu_dev_ctx = dynamic_cast(dev_ctx); - functor(cpu_dev_ctx, x, &out, &ddx, &ddout, dout, dx); + functor(cpu_dev_ctx, &x, out, &ddx, &ddout, dout, dx); platform::ForRange for_range(cpu_dev_ctx, limit); for_range(actual_functor); From c0fd88cd0c7ec4ae877294556207ecf1839dd008 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Thu, 13 Aug 2020 15:12:49 +0800 Subject: [PATCH 03/10] fix --- paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h index 182e56c25205c..cc2fe4cdbdb8f 100644 --- a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h +++ b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h @@ -88,7 +88,7 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim, ddx.data(), x.data(), static_cast(alpha), ddout_actual.data()); - int64_t limit = out.numel(); + int64_t limit = x.numel(); #ifdef __NVCC__ if (platform::is_gpu_place(place)) { 
From f008199d4a3e20f5917f91a81b52521e33182939 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Thu, 13 Aug 2020 17:41:58 +0800 Subject: [PATCH 04/10] fix --- .../fluid/tests/unittests/test_layers.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 9da70e85f01c0..017992ecc84e4 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -298,21 +298,6 @@ def test_relu(self): self.assertTrue(np.allclose(static_ret, dy_ret_value)) - def test_leakyrelu(self): - inputs = np.random.uniform(-1, 1, (10, 10)).astype('float32') - with self.static_graph(): - t = layers.data(name='t', shape=[10, 10], dtype='float32') - ret = layers.leaky_relu(t, alpha=0.01) - static_ret = self.get_static_graph_result( - feed={'t': inputs}, fetch_list=[ret])[0] - - with self.dynamic_graph(): - lrelu = paddle.nn.LeakyReLU(alpha=0.01) - dy_ret = lrelu(base.to_variable(inputs)) - dy_ret_value = dy_ret.numpy() - - self.assertTrue(np.allclose(static_ret, dy_ret_value)) - def test_pad2d(self): with self.static_graph(): t = layers.data(name='t', shape=[-1, 3, 5, 5], dtype='float32') @@ -2660,13 +2645,6 @@ def make_brelu(self): out = layers.brelu(input, t_min=1.0, t_max=20.0, name='brelu') return (out) - def make_leaky_relu(self): - with program_guard(fluid.default_main_program(), - fluid.default_startup_program()): - input = self._get_data(name="input", shape=[16], dtype="float32") - out = layers.leaky_relu(input, alpha=0.1, name='leaky_relu') - return (out) - def make_soft_relu(self): with program_guard(fluid.default_main_program(), fluid.default_startup_program()): From b25c5493d44d7eda76296087f744f899c620f811 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Sat, 15 Aug 2020 13:46:29 +0800 Subject: [PATCH 05/10] fix --- python/paddle/nn/layer/activation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 89efb3a99962f..589b4c60a8b26 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -288,9 +288,9 @@ class LeakyReLU(layers.Layer): paddle.disable_static() - lrelu = paddle.nn.LeakyReLU() + m = paddle.nn.LeakyReLU() x = paddle.to_variable(np.array([-2, 0, 1], 'float32')) - out = lrelu(x) # [-0.02, 0., 1.] + out = m(x) # [-0.02, 0., 1.] """ def __init__(self, negative_slope=0.01, name=None): From 3fad76358f71438a2295348d5fe2737452b0a7bf Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Sat, 15 Aug 2020 16:04:53 +0800 Subject: [PATCH 06/10] fix --- python/paddle/fluid/tests/unittests/test_activation_op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 7ef7a23578f28..7c4505e180391 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -682,6 +682,7 @@ def setUp(self): self.init_dtype() alpha = self.get_alpha() + np.random.seed(10) x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) # The same reason with TestAbs x[np.abs(x) < 0.005] = 0.05 From e865f969d5b84ebb1971a99be86af311c0655268 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Mon, 17 Aug 2020 15:18:52 +0800 Subject: [PATCH 07/10] fix --- python/paddle/fluid/tests/unittests/test_activation_op.py | 2 +- python/paddle/nn/layer/activation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 0b9bdfa19e9af..a5e8ae82e6fb4 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -675,7 +675,7 @@ def ref_leaky_relu(x, alpha=0.01): class 
TestLeakyRelu(TestActivation): def get_alpha(self): - return 0.01 + return 0.02 def setUp(self): self.op_type = "leaky_relu" diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 50a058dc1c46e..c2ad52d065d54 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -299,7 +299,7 @@ def __init__(self, negative_slope=0.01, name=None): self._name = name def forward(self, x): - return F.leaky_relu(x, self._alpha, self._name) + return F.leaky_relu(x, self._negative_slope, self._name) class Sigmoid(layers.Layer): From 62d4aad619be0f721e17a830b6ca9e5a7816b020 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 18 Aug 2020 20:48:34 +0800 Subject: [PATCH 08/10] fix --- python/paddle/nn/functional/activation.py | 2 +- python/paddle/nn/layer/activation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 7bd7b761d1bc9..b02aa769dffc9 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -254,7 +254,7 @@ def leaky_relu(x, negative_slope=0.01, name=None): \left\{ \begin{aligned} &x, & & if \ x >= 0 \\ - &negative_slope * x, & & otherwise \\ + &negative\_slope * x, & & otherwise \\ \end{aligned} \right. \\ diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index c2ad52d065d54..d37cb97094a52 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -266,7 +266,7 @@ class LeakyReLU(layers.Layer): \left\{ \begin{aligned} &x, & & if \ x >= 0 \\ - &negative_slope * x, & & otherwise \\ + &negative\_slope * x, & & otherwise \\ \end{aligned} \right. 
\\ From 05f1d0d377dab7689ea0200683fcd55f97b66595 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Thu, 20 Aug 2020 12:55:19 +0800 Subject: [PATCH 09/10] fix --- python/paddle/nn/functional/activation.py | 170 +++++++++++----------- python/paddle/nn/layer/activation.py | 12 +- 2 files changed, 91 insertions(+), 91 deletions(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 2d4f121b1d6bb..d31679314f0d2 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -85,16 +85,16 @@ def elu(x, alpha=1.0, name=None): Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) - out = F.elu(x, alpha=0.2) - # [[-0.12642411 6. ] - # [ 1. 15.6 ]] + x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) + out = F.elu(x, alpha=0.2) + # [[-0.12642411 6. ] + # [ 1. 15.6 ]] """ if in_dygraph_mode(): @@ -134,23 +134,23 @@ def gelu(x, approximate=False, name=None): Examples: .. 
code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - data = np.random.randn(2, 3).astype("float32") - x = paddle.to_tensor(data) + data = np.random.randn(2, 3).astype("float32") + x = paddle.to_tensor(data) - out = F.gelu(x) + out = F.gelu(x) - data - # array([[ 0.87165993, -1.0541513 , -0.37214822], - # [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32) - out - # array([[ 0.70456535, -0.15380788, -0.13207214], - # [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32) + data + # array([[ 0.87165993, -1.0541513 , -0.37214822], + # [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32) + out + # array([[ 0.70456535, -0.15380788, -0.13207214], + # [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32) """ if in_dygraph_mode(): @@ -195,14 +195,14 @@ def hardshrink(x, threshold=0.5, name=None): .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_variable(np.array([-1, 0.3, 2.5])) - out = F.hardshrink(x) # [-1., 0., 2.5] + x = paddle.to_variable(np.array([-1, 0.3, 2.5])) + out = F.hardshrink(x) # [-1., 0., 2.5] """ if in_dygraph_mode(): @@ -371,14 +371,14 @@ def leaky_relu(x, negative_slope=0.01, name=None): .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([-2, 0, 1])) - out = F.leaky_relu(x) # [-0.02, 0., 1.] + x = paddle.to_tensor(np.array([-2, 0, 1])) + out = F.leaky_relu(x) # [-0.02, 0., 1.] 
""" if in_dygraph_mode(): return core.ops.leaky_relu(x, 'alpha', negative_slope) @@ -414,14 +414,14 @@ def relu(x, name=None): Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) - out = F.relu(x) # [0., 0., 1.] + x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) + out = F.relu(x) # [0., 0., 1.] """ if in_dygraph_mode(): @@ -453,14 +453,14 @@ def logsigmoid(x, name=None): Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) - out = F.logsigmoid(x) # [0.7310586, 0.880797, 0.95257413, 0.98201376] + x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) + out = F.logsigmoid(x) # [0.7310586, 0.880797, 0.95257413, 0.98201376] """ if in_dygraph_mode(): @@ -566,26 +566,26 @@ def softmax(x, axis=-1, name=None): .. 
code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = np.array([[[2.0, 3.0, 4.0, 5.0], - [3.0, 4.0, 5.0, 6.0], - [7.0, 8.0, 8.0, 9.0]], - [[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [6.0, 7.0, 8.0, 9.0]]], 'float32') - x = paddle.to_tensor(x) - out = F.softmax(x) - # [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426], - # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], - # [0.07232949, 0.19661193, 0.19661193, 0.53444665]], - # [[0.0320586 , 0.08714432, 0.23688282, 0.64391426], - # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], - # [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]] + x = np.array([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]], 'float32') + x = paddle.to_tensor(x) + out = F.softmax(x) + # [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.07232949, 0.19661193, 0.19661193, 0.53444665]], + # [[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]] """ return paddle.fluid.layers.softmax(input=x, axis=axis, name=name) @@ -622,29 +622,29 @@ def log_softmax(x, axis=-1, dtype=None, name=None): Examples: .. 
code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np - - paddle.disable_static() + import paddle + import paddle.nn.functional as F + import numpy as np - x = np.array([[[-2.0, 3.0, -4.0, 5.0], - [3.0, -4.0, 5.0, -6.0], - [-7.0, -8.0, 8.0, 9.0]], - [[1.0, -2.0, -3.0, 4.0], - [-5.0, 6.0, 7.0, -8.0], - [6.0, 7.0, 8.0, 9.0]]], 'float32') - x = paddle.to_tensor(x) - out1 = F.log_softmax(x) - out2 = F.log_softmax(x, dtype='float64') - # out1's data type is float32; out2's data type is float64 - # out1 and out2's value is as follows: - # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] - # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] - # [-16.313261 -17.313261 -1.3132617 -0.31326184]] - # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] - # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] - # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] + paddle.disable_static() + + x = np.array([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]], 'float32') + x = paddle.to_tensor(x) + out1 = F.log_softmax(x) + out2 = F.log_softmax(x, dtype='float64') + # out1's data type is float32; out2's data type is float64 + # out1 and out2's value is as follows: + # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] + # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] + # [-16.313261 -17.313261 -1.3132617 -0.31326184]] + # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] + # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] + # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] """ if axis is None: diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 6716813221841..3c90b28c923a8 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -378,14 +378,14 @@ class LeakyReLU(layers.Layer): Examples: .. 
code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - m = paddle.nn.LeakyReLU() - x = paddle.to_tensor(np.array([-2, 0, 1])) - out = m(x) # [-0.02, 0., 1.] + m = paddle.nn.LeakyReLU() + x = paddle.to_tensor(np.array([-2, 0, 1])) + out = m(x) # [-0.02, 0., 1.] """ def __init__(self, negative_slope=0.01, name=None): From 05beeec3aceeb83402df09422e6bea3b0396078d Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Thu, 20 Aug 2020 13:19:19 +0800 Subject: [PATCH 10/10] fix --- python/paddle/nn/layer/activation.py | 136 +++++++++++++-------------- 1 file changed, 68 insertions(+), 68 deletions(-) diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 3c90b28c923a8..1386eb6fff71e 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -54,16 +54,16 @@ class ELU(layers.Layer): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) - m = paddle.nn.ELU(0.2) - out = m(x) - # [[-0.12642411 6. ] - # [ 1. 15.6 ]] + x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) + m = paddle.nn.ELU(0.2) + out = m(x) + # [[-0.12642411 6. ] + # [ 1. 15.6 ]] """ def __init__(self, alpha=1.0, name=None): @@ -103,23 +103,23 @@ class GELU(layers.Layer): Examples: .. 
code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - data = np.random.randn(2, 3).astype("float32") - x = paddle.to_tensor(data) + data = np.random.randn(2, 3).astype("float32") + x = paddle.to_tensor(data) - m = paddle.nn.GELU() - out = m(x) + m = paddle.nn.GELU() + out = m(x) - data - # array([[ 0.87165993, -1.0541513 , -0.37214822], - # [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32) - out - # array([[ 0.70456535, -0.15380788, -0.13207214], - # [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32) + data + # array([[ 0.87165993, -1.0541513 , -0.37214822], + # [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32) + out + # array([[ 0.70456535, -0.15380788, -0.13207214], + # [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32) """ def __init__(self, approximate=False, name=None): @@ -159,14 +159,14 @@ class Hardshrink(layers.Layer): .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_variable(np.array([-1, 0.3, 2.5])) - m = paddle.nn.Hardshrink() - out = m(x) # [-1., 0., 2.5] + x = paddle.to_tensor(np.array([-1, 0.3, 2.5])) + m = paddle.nn.Hardshrink() + out = m(x) # [-1., 0., 2.5] """ def __init__(self, threshold=0.5, name=None): @@ -333,14 +333,14 @@ class ReLU(layers.Layer): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) - m = paddle.nn.ReLU() - out = m(x) # [0., 0., 1.] + x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) + m = paddle.nn.ReLU() + out = m(x) # [0., 0., 1.] """ def __init__(self, name=None): @@ -418,15 +418,15 @@ class Sigmoid(layers.Layer): .. 
code-block:: python - import numpy as np - import paddle - - paddle.disable_static() - input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32') - m = paddle.nn.Sigmoid() - x = paddle.to_variable(input_data) - output = m(x) - print(output.numpy()) # [0.7310586, 0.880797, 0.95257413, 0.98201376] + import numpy as np + import paddle + + paddle.disable_static() + input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32') + m = paddle.nn.Sigmoid() + x = paddle.to_tensor(input_data) + output = m(x) + print(output.numpy()) # [0.7310586, 0.880797, 0.95257413, 0.98201376] """ def __init__(self, name=None): @@ -457,14 +457,14 @@ class LogSigmoid(layers.Layer): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) - m = paddle.nn.LogSigmoid() - out = m(x) # [0.7310586, 0.880797, 0.95257413, 0.98201376] + x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) + m = paddle.nn.LogSigmoid() + out = m(x) # [-0.313262, -0.126928, -0.0485874, -0.0181499] """ def __init__(self, name=None): @@ -499,26 +499,26 @@ class LogSoftmax(layers.Layer): Examples: .. 
code-block:: python - import paddle - import numpy as np - - paddle.disable_static() - - x = np.array([[[-2.0, 3.0, -4.0, 5.0], - [3.0, -4.0, 5.0, -6.0], - [-7.0, -8.0, 8.0, 9.0]], - [[1.0, -2.0, -3.0, 4.0], - [-5.0, 6.0, 7.0, -8.0], - [6.0, 7.0, 8.0, 9.0]]]) - m = paddle.nn.LogSoftmax() - x = paddle.to_tensor(x) - out = m(x) - # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] - # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] - # [-16.313261 -17.313261 -1.3132617 -0.31326184]] - # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] - # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] - # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] + import paddle + import numpy as np + + paddle.disable_static() + + x = np.array([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]]) + m = paddle.nn.LogSoftmax() + x = paddle.to_tensor(x) + out = m(x) + # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] + # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] + # [-16.313261 -17.313261 -1.3132617 -0.31326184]] + # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] + # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] + # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] """ def __init__(self, axis=-1, name=None):