Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hardtanh prelu softmax, test=develop #26431

Merged
merged 12 commits into from
Aug 21, 2020
6 changes: 4 additions & 2 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,7 @@ def chunk_eval(input,
num_correct_chunks)


@deprecated(since="2.0.0", update_to="paddle.nn.functional.softmax")
def softmax(input, use_cudnn=False, name=None, axis=-1):
"""
This operator implements the softmax layer. The calculation process is as follows:
Expand Down Expand Up @@ -8601,7 +8602,7 @@ def log(x, name=None):
return out


@templatedoc()
@deprecated(since="2.0.0", update_to="paddle.nn.functional.relu")
def relu(x, name=None):
"""
${comment}
Expand Down Expand Up @@ -9260,7 +9261,7 @@ def pad2d(input,
return out


@templatedoc()
@deprecated(since="2.0.0", update_to="paddle.nn.functional.elu")
def elu(x, alpha=1.0, name=None):
"""
:alias_main: paddle.nn.functional.elu
Expand Down Expand Up @@ -9576,6 +9577,7 @@ def swish(x, beta=1.0, name=None):
return out


@deprecated(since="2.0.0", update_to="paddle.nn.functional.prelu")
def prelu(x, mode, param_attr=None, name=None):
"""
:api_attr: Static Graph
Expand Down
57 changes: 57 additions & 0 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,63 @@ def test_errors(self):
F.hardshrink(x_fp16)


def ref_hardtanh(x, min=-1.0, max=1.0):
out = np.copy(x)
out[np.abs(x - min) < 0.005] = min + 0.02
out[np.abs(x - max) < 0.005] = max + 0.02
out = np.minimum(np.maximum(x, min), max)
return out


class TestHardtanhAPI(unittest.TestCase):
# test paddle.nn.Hardtanh, paddle.nn.functional.hardtanh
def setUp(self):
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12])
out1 = F.hardtanh(x)
m = paddle.nn.Hardtanh()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_hardtanh(self.x_np)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_variable(self.x_np)
out1 = F.hardtanh(x)
m = paddle.nn.Hardtanh()
out2 = m(x)
out_ref = ref_hardtanh(self.x_np)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)

out1 = F.hardtanh(x, -2.0, 2.0)
m = paddle.nn.Hardtanh(-2.0, 2.0)
out2 = m(x)
out_ref = ref_hardtanh(self.x_np, -2.0, 2.0)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.hardtanh, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.hardtanh, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
F.hardtanh(x_fp16)


def ref_softshrink(x, threshold=0.5):
out = np.copy(x)
out = (out < -threshold) * (out + threshold) + (out > threshold) * (
Expand Down
133 changes: 122 additions & 11 deletions python/paddle/fluid/tests/unittests/test_prelu_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,134 @@
import numpy as np
import paddle.fluid as fluid
import six
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.nn.functional as F


def ref_prelu(x, weight):
x_t = x.copy()
weight = weight.reshape(1, -1, 1, 1)
neg_indices = x <= 0
assert x.shape == neg_indices.shape
x_t[neg_indices] = (x_t * weight)[neg_indices]
return (x_t, )


def ref_prelu_nn(x, num_parameters, init):
weight_np = np.full((num_parameters), init)
return ref_prelu(x, weight_np)


class TestPReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
class TestFunctionalPReluAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
self.x_np = np.random.uniform(-1., 1., [1, 2, 3, 4]).astype('float32')
self.weight_np_0 = np.random.randn(1).astype('float32')
self.weight_np_1 = np.random.randn(self.x_np.shape[1]).astype('float32')

def static_check(self, weight_np):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_np.shape, 'float32')
weight = paddle.data('Alpha', weight_np.shape, 'float32')
out = F.prelu(x, weight)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np,
'Alpha': weight_np},
fetch_list=[out])
out_ref = ref_prelu(self.x_np, weight_np)
self.assertEqual(np.allclose(out_ref, res[0]), True)

def dygraph_check(self, weight_np):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
weight = paddle.to_tensor(weight_np)
out = F.prelu(x, weight)
out_ref = ref_prelu(self.x_np, weight_np)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)
paddle.enable_static()

def test_static_api(self):
self.static_check(self.weight_np_0)
self.static_check(self.weight_np_1)

def test_dygraph_api(self):
self.dygraph_check(self.weight_np_0)
self.dygraph_check(self.weight_np_1)

def test_error(self):
with paddle.static.program_guard(paddle.static.Program()):
weight_fp32 = paddle.data(
name='weight_fp32', shape=[1], dtype='float32')
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.prelu, 0.1, 'all')
self.assertRaises(TypeError, F.prelu, x=1, weight=weight_fp32)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.prelu, x_int32, 'all')
# support the input dtype is float32
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float32')
fluid.layers.prelu(x_fp16, 'all')
x_int32 = paddle.data(name='x_int32', shape=[2, 3], dtype='int32')
self.assertRaises(TypeError, F.prelu, x=x_int32, weight=weight_fp32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[2, 3], dtype='float16')
F.prelu(x=x_fp16, weight=weight_fp32)


class TestNNPReluAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
self.x_np = np.ones([1, 2, 3, 4]).astype('float32')

def test_static_api(self):
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
x = paddle.data(name='X', shape=self.x_np.shape, dtype='float32')
m = paddle.nn.PReLU()
out = m(x)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
res = exe.run(train_program,
feed={'X': self.x_np},
fetch_list=[out])
out_ref = ref_prelu_nn(self.x_np, 1, 0.25)
self.assertEqual(np.allclose(out_ref, res[0]), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU()
out = m(x)
out_ref = ref_prelu_nn(self.x_np, 1, 0.25)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU(num_parameters=self.x_np.shape[1])
out = m(x)
out_ref = ref_prelu_nn(self.x_np, self.x_np.shape[1], 0.25)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU(init=0.5)
out = m(x)
out_ref = ref_prelu_nn(self.x_np, 1, 0.5)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU(weight_attr=fluid.ParamAttr(name="weight"))
out = m(x)
out_ref = ref_prelu_nn(self.x_np, 1, 0.25)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

x = paddle.to_tensor(self.x_np)
m = paddle.nn.PReLU(weight_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.5)))
out = m(x)
out_ref = ref_prelu_nn(self.x_np, 1, 0.5)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

paddle.enable_static()


class PReluTest(OpTest):
Expand Down
79 changes: 46 additions & 33 deletions python/paddle/fluid/tests/unittests/test_softmax_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def stable_softmax(x):
return exps / np.sum(exps)


def ref_softmax(x, axis=None, dtype=None):
x_t = x.copy()
if dtype is not None:
x_t = x_t.astype(dtype)
if axis is None:
axis = -1
return np.apply_along_axis(stable_softmax, axis, x_t)


class TestSoftmaxOp(OpTest):
def get_x_shape(self):
return [10, 10]
Expand Down Expand Up @@ -93,20 +102,6 @@ def test_check_grad(self):
check_dygraph=(self.use_mkldnn == False))


class TestSoftmaxOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of softmax_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.softmax, x1)
# The input dtype of softmax_op must be float16, float32 or float64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.softmax, x2)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16")
fluid.layers.softmax(x3)


class TestSoftmaxOp2(TestSoftmaxOp):
def get_x_shape(self):
return [2, 3, 4, 5]
Expand Down Expand Up @@ -224,41 +219,59 @@ def get_x_shape(self):
return [2, 3, 4, 5]


class TestNnFunctionalSoftmaxApi(unittest.TestCase):
class TestSoftmaxAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
self.x_np = np.random.uniform(-1., 1., [2, 3, 4, 5]).astype('float32')
self.out_ref = np.apply_along_axis(stable_softmax, -1, self.x_np)

def test_api_static(self):
with program_guard(Program()):
def test_static_check(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_np.shape, 'float32')
out = F.softmax(x)
out1 = F.softmax(x)
m = paddle.nn.Softmax()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
self.assertEqual(np.allclose(self.out_ref, res[0]), True)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_api_imperative(self):
def test_dygraph_check(self):
paddle.disable_static(self.place)

x = paddle.to_variable(self.x_np)
out = F.softmax(x)
self.assertEqual(np.allclose(self.out_ref, out.numpy()), True)

out = F.softmax(x, axis=0)
out_ref = np.apply_along_axis(stable_softmax, 0, self.x_np)
x = paddle.to_tensor(self.x_np)
out1 = F.softmax(x)
m = paddle.nn.Softmax()
out2 = m(x)
out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)

out1 = F.softmax(x, axis=0)
m = paddle.nn.Softmax(axis=0)
out2 = m(x)
out_ref = ref_softmax(self.x_np, axis=0, dtype=None)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)

out = F.softmax(x, dtype=np.float64)
out_ref = ref_softmax(self.x_np, axis=-1, dtype=np.float64)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)

paddle.enable_static()

def test_error(self):
with program_guard(Program(), Program()):
# The x should be variable and its dtype should be float32, float64.
self.assertRaises(TypeError, F.softmax, [1])

x = paddle.data(name='x', shape=[2, 3], dtype='int32')
self.assertRaises(TypeError, F.softmax, x)
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.softmax, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[2, 3], dtype='int32')
self.assertRaises(TypeError, F.softmax, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[2, 3], dtype='float16')
F.softmax(x_fp16)


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@
from .layer.activation import ELU
from .layer.activation import GELU
from .layer.activation import Hardshrink
# from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import Hardtanh
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加一下 #DEFINE_ALIAS

from .layer.activation import PReLU
from .layer.activation import ReLU
from .layer.activation import ReLU6 #DEFINE_ALIAS
from .layer.activation import SELU #DEFINE_ALIAS
from .layer.activation import LeakyReLU #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
from .layer.activation import LogSigmoid
# from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import Softmax #DEFINE_ALIAS
from .layer.activation import Softplus #DEFINE_ALIAS
from .layer.activation import Softshrink #DEFINE_ALIAS
from .layer.activation import Softsign #DEFINE_ALIAS
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
from .activation import erf #DEFINE_ALIAS
from .activation import gelu #DEFINE_ALIAS
from .activation import hardshrink #DEFINE_ALIAS
from .activation import hardtanh #DEFINE_ALIAS
from .activation import hard_sigmoid #DEFINE_ALIAS
from .activation import hard_swish #DEFINE_ALIAS
from .activation import hsigmoid #DEFINE_ALIAS
from .activation import leaky_relu #DEFINE_ALIAS
from .activation import logsigmoid #DEFINE_ALIAS
from .activation import maxout #DEFINE_ALIAS
# from .activation import prelu #DEFINE_ALIAS
from .activation import prelu #DEFINE_ALIAS
from .activation import relu #DEFINE_ALIAS
from .activation import relu6 #DEFINE_ALIAS
from .activation import selu #DEFINE_ALIAS
Expand Down
Loading