diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index efa60b70001e5..162ad5887cc87 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1189,6 +1189,7 @@ def chunk_eval(input, num_correct_chunks) +@deprecated(since="2.0.0", update_to="paddle.nn.functional.softmax") def softmax(input, use_cudnn=False, name=None, axis=-1): """ This operator implements the softmax layer. The calculation process is as follows: @@ -8601,7 +8602,7 @@ def log(x, name=None): return out -@templatedoc() +@deprecated(since="2.0.0", update_to="paddle.nn.functional.relu") def relu(x, name=None): """ ${comment} @@ -9260,7 +9261,7 @@ def pad2d(input, return out -@templatedoc() +@deprecated(since="2.0.0", update_to="paddle.nn.functional.elu") def elu(x, alpha=1.0, name=None): """ :alias_main: paddle.nn.functional.elu @@ -9576,6 +9577,7 @@ def swish(x, beta=1.0, name=None): return out +@deprecated(since="2.0.0", update_to="paddle.nn.functional.prelu") def prelu(x, mode, param_attr=None, name=None): """ :api_attr: Static Graph diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index db32976f046d3..bf10bcd084ba7 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -525,6 +525,63 @@ def test_errors(self): F.hardshrink(x_fp16) +def ref_hardtanh(x, min=-1.0, max=1.0): + out = np.copy(x) + out[np.abs(x - min) < 0.005] = min + 0.02 + out[np.abs(x - max) < 0.005] = max + 0.02 + out = np.minimum(np.maximum(x, min), max) + return out + + +class TestHardtanhAPI(unittest.TestCase): + # test paddle.nn.Hardtanh, paddle.nn.functional.hardtanh + def setUp(self): + self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32') + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [10, 12]) + out1 = F.hardtanh(x) + m = paddle.nn.Hardtanh() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_hardtanh(self.x_np) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_variable(self.x_np) + out1 = F.hardtanh(x) + m = paddle.nn.Hardtanh() + out2 = m(x) + out_ref = ref_hardtanh(self.x_np) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + out1 = F.hardtanh(x, -2.0, 2.0) + m = paddle.nn.Hardtanh(-2.0, 2.0) + out2 = m(x) + out_ref = ref_hardtanh(self.x_np, -2.0, 2.0) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.hardtanh, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.hardtanh, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16') + F.hardtanh(x_fp16) + + def ref_softshrink(x, threshold=0.5): out = np.copy(x) out = (out < -threshold) * (out + threshold) + (out > threshold) * ( diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index 0a38bd277bfd1..16388ff8f5f04 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -18,23 +18,134 @@ import numpy as np import paddle.fluid as fluid import six -import paddle.fluid as fluid +import paddle.fluid.core as core from paddle.fluid import Program, program_guard from op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.nn.functional as F + + +def ref_prelu(x, weight): + x_t = x.copy() + weight = weight.reshape(1, -1, 1, 1) + neg_indices = x <= 0 + assert x.shape == neg_indices.shape + x_t[neg_indices] = (x_t * weight)[neg_indices] + return (x_t, ) + + +def ref_prelu_nn(x, num_parameters, init): + weight_np = np.full((num_parameters), init) + return ref_prelu(x, weight_np) -class TestPReluOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program()): +class TestFunctionalPReluAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + self.x_np = np.random.uniform(-1., 1., [1, 2, 3, 4]).astype('float32') + self.weight_np_0 = np.random.randn(1).astype('float32') + self.weight_np_1 = np.random.randn(self.x_np.shape[1]).astype('float32') + + def static_check(self, weight_np): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_np.shape, 'float32') + weight = paddle.data('Alpha', weight_np.shape, 'float32') + out = F.prelu(x, weight) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np, + 'Alpha': weight_np}, + fetch_list=[out]) + out_ref = ref_prelu(self.x_np, weight_np) + self.assertEqual(np.allclose(out_ref, res[0]), True) + + def dygraph_check(self, weight_np): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + weight = paddle.to_tensor(weight_np) + out = F.prelu(x, weight) + out_ref = ref_prelu(self.x_np, weight_np) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + paddle.enable_static() + + def test_static_api(self): + self.static_check(self.weight_np_0) + self.static_check(self.weight_np_1) + + def test_dygraph_api(self): + self.dygraph_check(self.weight_np_0) + self.dygraph_check(self.weight_np_1) + + def test_error(self): + with paddle.static.program_guard(paddle.static.Program()): + weight_fp32 = paddle.data( + name='weight_fp32', shape=[1], dtype='float32') # The input type must be Variable. - self.assertRaises(TypeError, fluid.layers.prelu, 0.1, 'all') + self.assertRaises(TypeError, F.prelu, x=1, weight=weight_fp32) # The input dtype must be float16, float32, float64. - x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, fluid.layers.prelu, x_int32, 'all') - # support the input dtype is float32 - x_fp16 = fluid.layers.data( - name='x_fp16', shape=[12, 10], dtype='float32') - fluid.layers.prelu(x_fp16, 'all') + x_int32 = paddle.data(name='x_int32', shape=[2, 3], dtype='int32') + self.assertRaises(TypeError, F.prelu, x=x_int32, weight=weight_fp32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[2, 3], dtype='float16') + F.prelu(x=x_fp16, weight=weight_fp32) + + +class TestNNPReluAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + self.x_np = np.ones([1, 2, 3, 4]).astype('float32') + + def test_static_api(self): + startup_program = paddle.static.Program() + train_program = paddle.static.Program() + with paddle.static.program_guard(train_program, startup_program): + x = paddle.data(name='X', shape=self.x_np.shape, dtype='float32') + m = paddle.nn.PReLU() + out = m(x) + exe = paddle.static.Executor(self.place) + exe.run(startup_program) + res = exe.run(train_program, + feed={'X': self.x_np}, + fetch_list=[out]) + out_ref = ref_prelu_nn(self.x_np, 1, 0.25) + self.assertEqual(np.allclose(out_ref, res[0]), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + + x = paddle.to_tensor(self.x_np) + m = paddle.nn.PReLU() + out = m(x) + out_ref = ref_prelu_nn(self.x_np, 1, 0.25) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + + x = paddle.to_tensor(self.x_np) + m = paddle.nn.PReLU(num_parameters=self.x_np.shape[1]) + out = m(x) + out_ref = ref_prelu_nn(self.x_np, self.x_np.shape[1], 0.25) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + + x = paddle.to_tensor(self.x_np) + m = paddle.nn.PReLU(init=0.5) + out = m(x) + out_ref = ref_prelu_nn(self.x_np, 1, 0.5) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + + x = paddle.to_tensor(self.x_np) + m = paddle.nn.PReLU(weight_attr=fluid.ParamAttr(name="weight")) + out = m(x) + out_ref = ref_prelu_nn(self.x_np, 1, 0.25) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + + x = paddle.to_tensor(self.x_np) + m = paddle.nn.PReLU(weight_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.5))) + out = m(x) + out_ref = ref_prelu_nn(self.x_np, 1, 0.5) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + + paddle.enable_static() class PReluTest(OpTest): diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 25e95216968b5..04d5cc941a463 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -35,6 +35,15 @@ def stable_softmax(x): return exps / np.sum(exps) +def ref_softmax(x, axis=None, dtype=None): + x_t = x.copy() + if dtype is not None: + x_t = x_t.astype(dtype) + if axis is None: + axis = -1 + return np.apply_along_axis(stable_softmax, axis, x_t) + + class TestSoftmaxOp(OpTest): def get_x_shape(self): return [10, 10] @@ -93,20 +102,6 @@ def test_check_grad(self): check_dygraph=(self.use_mkldnn == False)) -class TestSoftmaxOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - # The input type of softmax_op must be Variable. - x1 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.CPUPlace()) - self.assertRaises(TypeError, fluid.layers.softmax, x1) - # The input dtype of softmax_op must be float16, float32 or float64. - x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32") - self.assertRaises(TypeError, fluid.layers.softmax, x2) - x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16") - fluid.layers.softmax(x3) - - class TestSoftmaxOp2(TestSoftmaxOp): def get_x_shape(self): return [2, 3, 4, 5] @@ -224,41 +219,59 @@ def get_x_shape(self): return [2, 3, 4, 5] -class TestNnFunctionalSoftmaxApi(unittest.TestCase): +class TestSoftmaxAPI(unittest.TestCase): def setUp(self): self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() self.x_np = np.random.uniform(-1., 1., [2, 3, 4, 5]).astype('float32') self.out_ref = np.apply_along_axis(stable_softmax, -1, self.x_np) - def test_api_static(self): - with program_guard(Program()): + def test_static_check(self): + with paddle.static.program_guard(paddle.static.Program()): x = paddle.data('X', self.x_np.shape, 'float32') - out = F.softmax(x) + out1 = F.softmax(x) + m = paddle.nn.Softmax() + out2 = m(x) exe = paddle.static.Executor(self.place) - res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) - self.assertEqual(np.allclose(self.out_ref, res[0]), True) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_softmax(self.x_np, axis=-1, dtype=None) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) - def test_api_imperative(self): + def test_dygraph_check(self): paddle.disable_static(self.place) - x = paddle.to_variable(self.x_np) - out = F.softmax(x) - self.assertEqual(np.allclose(self.out_ref, out.numpy()), True) - - out = F.softmax(x, axis=0) - out_ref = np.apply_along_axis(stable_softmax, 0, self.x_np) + x = paddle.to_tensor(self.x_np) + out1 = F.softmax(x) + m = paddle.nn.Softmax() + out2 = m(x) + out_ref = ref_softmax(self.x_np, axis=-1, dtype=None) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + out1 = F.softmax(x, axis=0) + m = paddle.nn.Softmax(axis=0) + out2 = m(x) + out_ref = ref_softmax(self.x_np, axis=0, dtype=None) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + out = F.softmax(x, dtype=np.float64) + out_ref = ref_softmax(self.x_np, axis=-1, dtype=np.float64) self.assertEqual(np.allclose(out_ref, out.numpy()), True) paddle.enable_static() def test_error(self): - with program_guard(Program(), Program()): - # The x should be variable and its dtype should be float32, float64. - self.assertRaises(TypeError, F.softmax, [1]) - - x = paddle.data(name='x', shape=[2, 3], dtype='int32') - self.assertRaises(TypeError, F.softmax, x) + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.softmax, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.data(name='x_int32', shape=[2, 3], dtype='int32') + self.assertRaises(TypeError, F.softmax, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[2, 3], dtype='float16') + F.softmax(x_fp16) if __name__ == "__main__": diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 3dd1c1d94fbd7..84c466c977eea 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -55,14 +55,15 @@ from .layer.activation import ELU from .layer.activation import GELU from .layer.activation import Hardshrink -# from .layer.activation import PReLU #DEFINE_ALIAS +from .layer.activation import Hardtanh +from .layer.activation import PReLU from .layer.activation import ReLU from .layer.activation import ReLU6 #DEFINE_ALIAS from .layer.activation import SELU #DEFINE_ALIAS from .layer.activation import LeakyReLU #DEFINE_ALIAS from .layer.activation import Sigmoid #DEFINE_ALIAS from .layer.activation import LogSigmoid -# from .layer.activation import Softmax #DEFINE_ALIAS +from .layer.activation import Softmax #DEFINE_ALIAS from .layer.activation import Softplus #DEFINE_ALIAS from .layer.activation import Softshrink #DEFINE_ALIAS from .layer.activation import Softsign #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index ff2b1edf6723b..7c903ccf53c58 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -30,13 +30,14 @@ from .activation import erf #DEFINE_ALIAS from .activation import gelu #DEFINE_ALIAS from .activation import hardshrink #DEFINE_ALIAS +from .activation import hardtanh #DEFINE_ALIAS from .activation import hard_sigmoid #DEFINE_ALIAS from .activation import hard_swish #DEFINE_ALIAS from .activation import hsigmoid #DEFINE_ALIAS from .activation import leaky_relu #DEFINE_ALIAS from .activation import logsigmoid #DEFINE_ALIAS from .activation import maxout #DEFINE_ALIAS -# from .activation import prelu #DEFINE_ALIAS +from .activation import prelu #DEFINE_ALIAS from .activation import relu #DEFINE_ALIAS from .activation import relu6 #DEFINE_ALIAS from .activation import selu #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 16a86ce2e8cb6..619492a1b2bb1 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -30,13 +30,14 @@ 'erf', 'gelu', 'hardshrink', + 'hardtanh', 'hard_sigmoid', 'hard_swish', 'hsigmoid', 'leaky_relu', 'logsigmoid', 'maxout', - # 'prelu', + 'prelu', 'relu', 'relu6', 'selu', @@ -49,7 +50,7 @@ 'swish', 'tanhshrink', 'thresholded_relu', - 'log_softmax' + 'log_softmax', ] import warnings @@ -64,7 +65,7 @@ def elu(x, alpha=1.0, name=None): """ elu activation. - .. math:: + .. math:: elu(x) = max(0, x) + min(0, \\alpha * (e^{x}-1)) @@ -80,16 +81,16 @@ def elu(x, alpha=1.0, name=None): Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) - out = F.elu(x, alpha=0.2) - # [[-0.12642411 6. ] - # [ 1. 15.6 ]] + x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) + out = F.elu(x, alpha=0.2) + # [[-0.12642411 6. ] + # [ 1. 15.6 ]] """ if in_dygraph_mode(): @@ -111,10 +112,15 @@ def gelu(x, approximate=False, name=None): gelu activation. if approximate is True - .. math:: + + .. math:: + gelu(x) = 0.5 * x * (1 + tanh(\\sqrt{\\frac{2}{\\pi}} * (x + 0.044715x^{3}))) + else - .. math:: + + .. math:: + gelu(x) = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}})) Parameters: @@ -129,23 +135,15 @@ def gelu(x, approximate=False, name=None): Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np - - paddle.disable_static() - - data = np.random.randn(2, 3).astype("float32") - x = paddle.to_tensor(data) + import paddle + import paddle.nn.functional as F + import numpy as np - out = F.gelu(x) + paddle.disable_static() - data - # array([[ 0.87165993, -1.0541513 , -0.37214822], - # [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32) - out - # array([[ 0.70456535, -0.15380788, -0.13207214], - # [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32) + x = paddle.to_tensor(np.array([[-1, 0.5],[1, 1.5]])) + out1 = F.gelu(x) # [-0.158655 0.345731 0.841345 1.39979] + out2 = F.gelu(x, True) # [-0.158808 0.345714 0.841192 1.39957] """ if in_dygraph_mode(): @@ -187,17 +185,16 @@ def hardshrink(x, threshold=0.5, name=None): A Tensor with the same data type and shape as ``x`` . Examples: - .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_variable(np.array([-1, 0.3, 2.5])) - out = F.hardshrink(x) # [-1., 0., 2.5] + x = paddle.to_tensor(np.array([-1, 0.3, 2.5])) + out = F.hardshrink(x) # [-1., 0., 2.5] """ if in_dygraph_mode(): @@ -215,6 +212,58 @@ def hardshrink(x, threshold=0.5, name=None): return out +def hardtanh(x, min=-1.0, max=1.0, name=None): + """ + hardtanh activation + + .. math:: + + hardtanh(x)= \\begin{cases} + max, \\text{if } x > max \\\\ + min, \\text{if } x < min \\\\ + x, \\text{otherwise} + \\end{cases} + + Args: + x (Tensor): The input Tensor with data type float32, float64. + min (float, optional): The minimum value of the linear region range. Default is -1. + max (float, optional): The maximum value of the linear region range. Default is 1. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-1.5, 0.3, 2.5])) + out = F.hardtanh(x) # [-1., 0.3, 1.] + """ + + if in_dygraph_mode(): + return core.ops.brelu(x, 't_min', min, 't_max', max) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'hardtanh') + + helper = LayerHelper('hardtanh', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='brelu', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'t_min': min, + 't_max': max}) + return out + + def hsigmoid(input, label, weight, @@ -272,7 +321,6 @@ def hsigmoid(input, Variable: A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as :attr:`input`. Examples: - .. code-block:: python from paddle import fluid, nn @@ -338,11 +386,86 @@ def hsigmoid(input, return out +def prelu(x, weight, name=None): + """ + prelu activation. + + .. math:: + + prelu(x) = max(0, x) + weight * min(0, x) + + Parameters: + x (Tensor): The input Tensor with data type float32, float64. + weight (Tensor): The learnable parameter with data type same as ``x``. + The weight shape is [1] or [in], where `in` is the input channel of ``x``. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + data = np.array([[[[-2.0, 3.0, -4.0, 5.0], + [ 3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[ 1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [ 6.0, 7.0, 8.0, 9.0]]]], 'float32') + x = paddle.to_tensor(data) + w = paddle.to_tensor(np.array([0.25]).astype('float32')) + out = F.prelu(x, w) + # [[[[-0.5 , 3. , -1. , 5. ], + # [ 3. , -1. , 5. , -1.5 ], + # [-1.75, -2. , 8. , 9. ]], + # [[ 1. , -0.5 , -0.75, 4. ], + # [-1.25, 6. , 7. , -2. ], + # [ 6. , 7. , 8. , 9. ]]]] + """ + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu') + check_variable_and_dtype(weight, 'weight', + ['float16', 'float32', 'float64'], 'prelu') + + helper = LayerHelper('prelu', **locals()) + assert len(weight.shape + ) == 1, "The dim count of weight shape should be 1 in prelu()." + + # NOTE(): The input of this API should be ``N,C,...`` format, + # which means x.shape[0] is batch_size and x.shape[0] is channel. + mode = 'all' + if weight.shape[0] > 1: + assert len( + x.shape + ) > 1, "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]." + assert weight.shape[0] == x.shape[ + 1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]." + mode = 'channel' + + if in_dygraph_mode(): + return core.ops.prelu(x, weight, 'mode', mode) + + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type="prelu", + inputs={"X": x, + "Alpha": weight}, + outputs={"Out": out}, + attrs={"mode": mode}) + return out + + def relu(x, name=None): """ - ReLU Activation. + relu activation. - .. math: + .. math:: out = max(x, 0) @@ -357,14 +480,14 @@ def relu(x, name=None): Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) - out = F.relu(x) # [0., 0., 1.] + x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) + out = F.relu(x) # [0., 0., 1.] """ if in_dygraph_mode(): @@ -381,9 +504,9 @@ def logsigmoid(x, name=None): """ logsigmoid activation. - .. math: + .. math:: - logsigmoid(x) = \log \frac{1}{1 + e^{-x}} + logsigmoid(x) = log \\frac{1}{1 + e^{-x}} Parameters: x (Tensor): The input Tensor with data type float32, float64. @@ -396,14 +519,14 @@ def logsigmoid(x, name=None): Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) - out = F.logsigmoid(x) # [0.7310586, 0.880797, 0.95257413, 0.98201376] + x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) + out = F.logsigmoid(x) # [-0.313262 -0.126928 -0.0485874 -0.0181499] """ if in_dygraph_mode(): @@ -514,7 +637,7 @@ def selu(x, return out -def softmax(x, axis=-1, name=None): +def softmax(x, axis=-1, dtype=None, name=None): """ This operator implements the softmax layer. The calculation process is as follows: @@ -541,7 +664,7 @@ def softmax(x, axis=-1, name=None): .. math:: - out[i, j] = \\frac{\exp(x[i, j])}{\sum_j(exp(x[i, j])} + softmax[i, j] = \\frac{\\exp(x[i, j])}{\\sum_j(exp(x[i, j])} Example: @@ -590,44 +713,89 @@ def softmax(x, axis=-1, name=None): [0.26762315, 0.26762315, 0.26762315, 0.26762315], [0.72747516, 0.72747516, 0.72747516, 0.72747516]]] - Args: - x (Tensor): The input multi-dimension Tensor with data type float32, float64. - axis (int, optional): The axis along which to perform softmax calculations. - It should be in range [-D, D), where D is the dimensions of ``x`` . - When ``axis`` < 0, it works the same way as :math:`axis + D` . - Default is -1. + Parameters: + x (Tensor): The input Tensor with data type float32, float64. + axis (int, optional): The axis along which to perform log_softmax + calculations. It should be in range [-D, D), where D is the + dimensions of ``x`` . If ``axis`` < 0, it works the same way as + :math:`axis + D` . Default is -1. + dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data + type of the output tensor. If dtype is specified, ``x`` is casted + to ``dtype`` before the operation is performed. This is useful for + preventing data type overflows. Supported dtype: float32, float64. + If ``dtype`` is None, the output Tensor has the same dtype as x. + Default is None. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - A Tensor with the same data type and shape as ``x`` . + A Tensor with the same shape and data type (use ``dtype`` if it is + specified) as x. Examples: - .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = np.array([[[2.0, 3.0, 4.0, 5.0], - [3.0, 4.0, 5.0, 6.0], - [7.0, 8.0, 8.0, 9.0]], - [[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [6.0, 7.0, 8.0, 9.0]]], 'float32') - x = paddle.to_tensor(x) - out = F.softmax(x) - # [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426], - # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], - # [0.07232949, 0.19661193, 0.19661193, 0.53444665]], - # [[0.0320586 , 0.08714432, 0.23688282, 0.64391426], - # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], - # [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]] + x = np.array([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]], 'float32') + x = paddle.to_tensor(x) + out1 = F.softmax(x) + out2 = F.softmax(x, dtype='float64') + # out1's data type is float32; out2's data type is float64 + # out1 and out2's value is as follows: + # [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.07232949, 0.19661193, 0.19661193, 0.53444665]], + # [[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]] """ - return paddle.fluid.layers.softmax(input=x, axis=axis, name=name) + + if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)): + dtype = convert_np_dtype_to_dtype_(dtype) + use_cudnn = True if axis is -1 else False + + if in_dygraph_mode(): + outs_cast = x if dtype is None \ + else core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype) + return core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn) + + if dtype is None: + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'softmax') + else: + check_dtype(dtype, 'dtype', ['float32', 'float64'], 'softmax', + 'If dtype is not None, it only support float32 or float64.') + + helper = LayerHelper("softmax", **locals()) + outs_cast = x + if dtype is not None: + outs_cast = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='cast', + inputs={'X': x}, + outputs={'Out': outs_cast}, + attrs={'in_dtype': x.dtype, + 'out_dtype': dtype}) + + outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype) + helper.append_op( + type='softmax', + inputs={'X': outs_cast}, + outputs={'Out': outs_softmax}, + attrs={'axis': axis, + 'use_cudnn': use_cudnn}) + + return outs_softmax def softplus(x, beta=1, threshold=20, name=None): @@ -820,7 +988,7 @@ def log_softmax(x, axis=-1, dtype=None, name=None): .. math:: Out[i, j] = log(softmax(x)) - = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) + = log(\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) Parameters: x (Tensor): The input Tensor with data type float32, float64. @@ -844,33 +1012,31 @@ def log_softmax(x, axis=-1, dtype=None, name=None): Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - paddle.disable_static() + paddle.disable_static() + + x = np.array([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]], 'float32') + x = paddle.to_tensor(x) + out1 = F.log_softmax(x) + out2 = F.log_softmax(x, dtype='float64') + # out1's data type is float32; out2's data type is float64 + # out1 and out2's value is as follows: + # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] + # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] + # [-16.313261 -17.313261 -1.3132617 -0.31326184]] + # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] + # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] + # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] + """ - x = np.array([[[-2.0, 3.0, -4.0, 5.0], - [3.0, -4.0, 5.0, -6.0], - [-7.0, -8.0, 8.0, 9.0]], - [[1.0, -2.0, -3.0, 4.0], - [-5.0, 6.0, 7.0, -8.0], - [6.0, 7.0, 8.0, 9.0]]], 'float32') - x = paddle.to_tensor(x) - out1 = F.log_softmax(x) - out2 = F.log_softmax(x, dtype='float64') - # out1's data type is float32; out2's data type is float64 - # out1 and out2's value is as follows: - # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] - # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] - # [-16.313261 -17.313261 -1.3132617 -0.31326184]] - # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] - # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] - # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] - """ - - if axis is None: - axis = -1 if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)): dtype = convert_np_dtype_to_dtype_(dtype) diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index bb294467bb3dd..bfa2e2f25d873 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -18,25 +18,28 @@ 'ELU', 'GELU', 'Hardshrink', - # 'PReLU', + 'Hardtanh', + 'PReLU', 'ReLU', 'ReLU6', 'SELU', 'LeakyReLU', 'Sigmoid', - # 'Softmax', + 'Softmax', 'Softplus', 'Softshrink', 'Softsign', 'Tanhshrink', 'LogSigmoid', 'LogSoftmax', - 'HSigmoid' + 'HSigmoid', ] from ...fluid.dygraph import layers from ...fluid import core from ...fluid.framework import in_dygraph_mode +from ...fluid.param_attr import ParamAttr +from ...fluid.initializer import Constant from .. import functional as F @@ -44,7 +47,7 @@ class ELU(layers.Layer): """ ELU Activation. - .. math:: + .. math:: ELU(x) = max(0, x) + min(0, \\alpha * (e^{x}-1)) @@ -60,16 +63,16 @@ class ELU(layers.Layer): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) - m = paddle.nn.ELU(0.2) - out = m(x) - # [[-0.12642411 6. ] - # [ 1. 15.6 ]] + x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) + m = paddle.nn.ELU(0.2) + out = m(x) + # [[-0.12642411 6. ] + # [ 1. 15.6 ]] """ def __init__(self, alpha=1.0, name=None): @@ -87,13 +90,13 @@ class GELU(layers.Layer): If approximate is True - .. math:: + .. math:: GELU(x) = 0.5 * x * (1 + tanh(\\sqrt{\\frac{2}{\\pi}} * (x + 0.044715x^{3}))) else - .. math:: + .. math:: GELU(x) = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}})) @@ -109,23 +112,18 @@ class GELU(layers.Layer): Examples: .. code-block:: python - import paddle - import numpy as np - - paddle.disable_static() + import paddle + import numpy as np - data = np.random.randn(2, 3).astype("float32") - x = paddle.to_tensor(data) + paddle.disable_static() - m = paddle.nn.GELU() - out = m(x) + x = paddle.to_tensor(np.array([[-1, 0.5],[1, 1.5]])) + + m = paddle.nn.GELU() + out = m(x) # [-0.158655 0.345731 0.841345 1.39979] - data - # array([[ 0.87165993, -1.0541513 , -0.37214822], - # [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32) - out - # array([[ 0.70456535, -0.15380788, -0.13207214], - # [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32) + m = paddle.nn.GELU(True) + out = m(x) # [-0.158808 0.345714 0.841192 1.39957] """ def __init__(self, approximate=False, name=None): @@ -170,7 +168,7 @@ class Hardshrink(layers.Layer): paddle.disable_static() - x = paddle.to_variable(np.array([-1, 0.3, 2.5])) + x = paddle.to_tensor(np.array([-1, 0.3, 2.5])) m = paddle.nn.Hardshrink() out = m(x) # [-1., 0., 2.5] """ @@ -184,6 +182,51 @@ def forward(self, x): return F.hardshrink(x, self._threshold, self._name) +class Hardtanh(layers.Layer): + """ + Hardtanh Activation + + .. math:: + + Hardtanh(x)= \\begin{cases} + max, \\text{if } x > max \\\\ + min, \\text{if } x < min \\\\ + x, \\text{otherwise} + \\end{cases} + + Parameters: + min (float, optional): The value of min for Hardtanh. Default is -1. + max (float, optional): The value of max for Hardtanh. Default is 1. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([-1.5, 0.3, 2.5])) + m = paddle.nn.Hardtanh() + out = m(x) # # [-1., 0.3, 1.] + """ + + def __init__(self, min=-1.0, max=1.0, name=None): + super(Hardtanh, self).__init__() + self._min = min + self._max = max + self._name = name + + def forward(self, x): + return F.hardtanh(x, self._min, self._max, self._name) + + class HSigmoid(layers.Layer): """ :alias_main: paddle.nn.HSigmoid @@ -320,11 +363,78 @@ def forward(self, input, label, path_table=None, path_code=None): return out +class PReLU(layers.Layer): + """ + PReLU Activation. + + .. math:: + + PReLU(x) = max(0, x) + weight * min(0, x) + + Parameters: + num_parameters (int, optional): Number of `weight` to learn. The supported values are: + 1 - a single parameter `alpha` is used for all input channels; + Number of channels - a seperate `alpha` is used for each input channel. + Default is 1. + init (float, optional): Init value of learnable `weight`. Default is 0.25. + weight_attr(ParamAttr, optional): The parameter attribute for the learnable `weight`. + Default is None. For more information, please refer to :ref:`api_fluid_ParamAttr`. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + data = np.array([[[[-2.0, 3.0, -4.0, 5.0], + [ 3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[ 1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [ 6.0, 7.0, 8.0, 9.0]]]], 'float32') + x = paddle.to_tensor(data) + m = paddle.nn.PReLU(1, 0.25) + out = m(x) + # [[[[-0.5 , 3. , -1. , 5. ], + # [ 3. , -1. , 5. , -1.5 ], + # [-1.75, -2. , 8. , 9. ]], + # [[ 1. , -0.5 , -0.75, 4. ], + # [-1.25, 6. , 7. , -2. ], + # [ 6. , 7. , 8. , 9. ]]]] + """ + + def __init__(self, num_parameters=1, init=0.25, weight_attr=None, + name=None): + super(PReLU, self).__init__() + self._num_parameters = num_parameters + self._init = init + self._weight_attr = weight_attr + self._name = name + + self._weight = self.create_parameter( + attr=self._weight_attr, + shape=[num_parameters], + dtype='float32', + is_bias=False, + default_initializer=Constant(init)) + + def forward(self, x): + return F.prelu(x, self._weight) + + class ReLU(layers.Layer): """ ReLU Activation. - .. math: + .. math:: ReLU(x) = max(x, 0) @@ -339,14 +449,14 @@ class ReLU(layers.Layer): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) - m = paddle.nn.ReLU() - out = m(x) # [0., 0., 1.] + x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) + m = paddle.nn.ReLU() + out = m(x) # [0., 0., 1.] """ def __init__(self, name=None): @@ -488,7 +598,7 @@ class Sigmoid(layers.Layer): .. math:: - output = \\frac{1}{1 + e^{-x}} + Sigmoid(x) = \frac{1}{1 + e^{-x}} Parameters: name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -509,7 +619,7 @@ class Sigmoid(layers.Layer): paddle.disable_static() input_data = np.array([1.0, 2.0, 3.0, 4.0]).astype('float32') m = paddle.nn.Sigmoid() - x = paddle.to_variable(input_data) + x = paddle.to_tensor(input_data) output = m(x) print(output.numpy()) # [0.7310586, 0.880797, 0.95257413, 0.98201376] """ @@ -687,9 +797,9 @@ class LogSigmoid(layers.Layer): """ LogSigmoid Activation. - .. math: + .. math:: - LogSigmoid(x) = \log \frac{1}{1 + e^{-x}} + LogSigmoid(x) = log \\frac{1}{1 + e^{-x}} Parameters: x (Tensor): The input Tensor with data type float32, or float64. @@ -703,14 +813,14 @@ class LogSigmoid(layers.Layer): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) - m = paddle.nn.LogSigmoid() - out = m(x) # [0.7310586, 0.880797, 0.95257413, 0.98201376] + x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) + m = paddle.nn.LogSigmoid() + out = m(x) # [-0.313262 -0.126928 -0.0485874 -0.0181499] """ def __init__(self, name=None): @@ -721,6 +831,137 @@ def forward(self, x): return F.logsigmoid(x, self._name) +class Softmax(layers.Layer): + """ + Softmax Activation. + + This operator implements the softmax layer. The calculation process is as follows: + + 1. The dimension :attr:`axis` of ``x`` will be permuted to the last. + + 2. Then ``x`` will be logically flattened to a 2-D matrix. The matrix's second + dimension(row length) is the same as the dimension :attr:`axis` of ``x``, + and the first dimension(column length) is the product of all other dimensions + of ``x``. For each row of the matrix, the softmax operator squashes the + K-dimensional(K is the width of the matrix, which is also the size of ``x``'s + dimension :attr:`axis`) vector of arbitrary real values to a K-dimensional + vector of real values in the range [0, 1] that add up to 1. + + 3. After the softmax operation is completed, the inverse operations of steps 1 and 2 + are performed to restore the two-dimensional matrix to the same dimension as the ``x`` . + + It computes the exponential of the given dimension and the sum of exponential + values of all the other dimensions in the K-dimensional vector input. + Then the ratio of the exponential of the given dimension and the sum of + exponential values of all the other dimensions is the output of the softmax + operator. + + For each row :math:`i` and each column :math:`j` in the matrix, we have: + + .. math:: + + Softmax[i, j] = \\frac{\\exp(x[i, j])}{\\sum_j(exp(x[i, j])} + + Example: + + .. code-block:: text + + Case 1: + Input: + x.shape = [2, 3, 4] + x.data = [[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]] + + Attrs: + axis = -1 + + Output: + out.shape = [2, 3, 4] + out.data = [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + [0.07232949, 0.19661193, 0.19661193, 0.53444665]], + [[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]] + + Case 2: + Input: + x.shape = [2, 3, 4] + x.data = [[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]] + Attrs: + axis = 1 + + Output: + out.shape = [2, 3, 4] + out.data = [[[0.00657326, 0.00657326, 0.01714783, 0.01714783], + [0.01786798, 0.01786798, 0.04661262, 0.04661262], + [0.97555875, 0.97555875, 0.93623955, 0.93623955]], + [[0.00490169, 0.00490169, 0.00490169, 0.00490169], + [0.26762315, 0.26762315, 0.26762315, 0.26762315], + [0.72747516, 0.72747516, 0.72747516, 0.72747516]]] + + Parameters: + axis (int, optional): The axis along which to perform log_softmax + calculations. It should be in range [-D, D), where D is the + dimensions of ``x`` . If ``axis`` < 0, it works the same way as + :math:`axis + D` . Default is -1. + dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data + type of the output tensor. If dtype is specified, ``x`` is casted + to ``dtype`` before the operation is performed. This is useful for + preventing data type overflows. Supported dtype: float32, float64. + If ``dtype`` is None, the output Tensor has the same dtype as x. + Default is None. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = np.array([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]], 'float32') + x = paddle.to_tensor(x) + m = paddle.nn.Softmax() + out = m(x) + # [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.07232949, 0.19661193, 0.19661193, 0.53444665]], + # [[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + # [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]] + """ + + def __init__(self, axis=-1, name=None): + super(Softmax, self).__init__() + self._axis = axis + self._dtype = None + self._name = name + + def forward(self, x): + return F.softmax(x, self._axis, self._dtype, self._name) + + class LogSoftmax(layers.Layer): """ This operator implements the log_softmax layer. The calculation process is as follows: @@ -728,7 +969,7 @@ class LogSoftmax(layers.Layer): .. math:: Out[i, j] = log(softmax(x)) - = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) + = log(\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) Parameters: axis (int, optional): The axis along which to perform log_softmax