leaky_relu and LeakyReLU: alpha->negative_slope #26216

Merged
merged 16 commits on Aug 22, 2020
8 changes: 4 additions & 4 deletions paddle/fluid/operators/activation_op.cc
@@ -759,8 +759,8 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
}
};

// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
// leaky_relu Grad: dx=dy if x>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx
template <typename T>
class LeakyReluDoubleGradMaker
: public ::paddle::framework::SingleGradOpMaker<T> {
@@ -770,8 +770,8 @@ class LeakyReluDoubleGradMaker
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("leaky_relu_grad_grad");
// input1: Out
op->SetInput("Out", this->Input("Out"));
// input1: X
op->SetInput("X", this->Input("X"));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(this->Attrs());
28 changes: 16 additions & 12 deletions paddle/fluid/operators/activation_op.h
@@ -1070,7 +1070,11 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
if (alpha < 1.f) {
out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
} else {
out.device(d) = x.cwiseMin(static_cast<T>(alpha) * x);
}
}
};
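For context on the forward change above: leaky_relu(x) equals max(x, alpha * x) only when alpha <= 1; once alpha may exceed 1 (see the new alpha = 2 test case further down), the elementwise min is the correct form, which is why the functor now branches on alpha. A minimal NumPy sketch of the piecewise definition, illustrative only and not part of this diff (`leaky_relu_ref` is a hypothetical helper name):

```python
import numpy as np

def leaky_relu_ref(x, negative_slope=0.01):
    # Piecewise definition: x where x >= 0, negative_slope * x elsewhere.
    return np.where(x >= 0, x, negative_slope * x)

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(leaky_relu_ref(x, 2.0))   # [-4.  -1.   0.   1.5]
print(np.maximum(x, 2.0 * x))   # [-2.  -0.5  0.   3. ]  (wrong once slope > 1)
print(np.minimum(x, 2.0 * x))   # [-4.  -1.   0.   1.5]  (matches the cwiseMin branch)
```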

@@ -1084,12 +1088,12 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 =
static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
auto temp2 = (out > static_cast<T>(0)).template cast<T>();
static_cast<T>(alpha) * (x < static_cast<T>(0)).template cast<T>();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
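The first-order gradient is 1 where x >= 0 and alpha where x < 0. The mask now has to be computed from X rather than Out (hence the switch to kDepX), because with alpha > 1 or alpha < 0 the sign of Out no longer indicates which branch the forward took. A small sketch under the same assumptions as above:

```python
import numpy as np

def leaky_relu_grad_ref(x, dout, negative_slope=0.01):
    # dOut/dX is 1 where x >= 0 and negative_slope where x < 0,
    # so the branch mask must be derived from X, not from Out.
    return dout * np.where(x >= 0, 1.0, negative_slope)

x = np.array([-1.0, 2.0])
dout = np.ones_like(x)
print(leaky_relu_grad_ref(x, dout, -2.0))  # [-2.  1.]
# With negative_slope = -2, out = [2., 2.] is positive everywhere, so a mask
# computed from Out would (incorrectly) yield [1., 1.].
```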

template <typename T>
@@ -1437,18 +1441,18 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad"));
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad"));
ddout.device(*d) = ddx *
((out > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) *
(out <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
ddout.device(*d) =
ddx *
((x > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * (x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
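Because leaky_relu is piecewise linear, the double-grad is the same mask applied to ddX, again keyed on the sign of X. A one-line reference in the same sketch style (boundary behaviour at exactly x == 0 follows the per-element test functor below):

```python
import numpy as np

def leaky_relu_grad_grad_ref(x, ddx, negative_slope=0.01):
    # ddOut = ddX where x >= 0, negative_slope * ddX elsewhere.
    return ddx * np.where(x >= 0, 1.0, negative_slope)
```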

template <typename T>
24 changes: 12 additions & 12 deletions paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
@@ -41,20 +41,20 @@ static void InitRandom(framework::Tensor *tensor,

template <typename T>
struct LeakyReluGradGradEachElementFunctor {
LeakyReluGradGradEachElementFunctor(const T *ddx, const T *out, T alpha,
LeakyReluGradGradEachElementFunctor(const T *ddx, const T *x, T alpha,
T *ddout)
: ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {}
: ddx_(ddx), x_(x), alpha_(alpha), ddout_(ddout) {}

HOSTDEVICE void operator()(int idx) {
if (out_[idx] > 0) {
if (x_[idx] >= 0) {
ddout_[idx] = ddx_[idx];
} else {
ddout_[idx] = ddx_[idx] * alpha_;
}
}

const T *ddx_;
const T *out_;
const T *x_;
T alpha_;
T *ddout_;
};
@@ -66,13 +66,13 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
LeakyReluGradGradFunctor<T> functor;
functor.alpha = alpha;
auto &dev_ctx = *platform::DeviceContextPool::Instance().Get(place);
framework::Tensor *x = nullptr;
framework::Tensor *out = nullptr;
framework::Tensor *dout = nullptr;
framework::Tensor *dx = nullptr;

framework::Tensor out;
out.Resize(dim);
InitRandom<T>(&out, place);
framework::Tensor x;
x.Resize(dim);
InitRandom<T>(&x, place);

framework::Tensor ddx;
ddx.Resize(dim);
@@ -85,22 +85,22 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
framework::Tensor ddout_actual;
ddout_actual.mutable_data<T>(dim, place);
LeakyReluGradGradEachElementFunctor<T> actual_functor(
ddx.data<T>(), out.data<T>(), static_cast<T>(alpha),
ddx.data<T>(), x.data<T>(), static_cast<T>(alpha),
ddout_actual.data<T>());

int64_t limit = out.numel();
int64_t limit = x.numel();

#ifdef __NVCC__
if (platform::is_gpu_place(place)) {
auto &cuda_dev_ctx = dynamic_cast<platform::CUDADeviceContext &>(dev_ctx);
functor(cuda_dev_ctx, x, &out, &ddx, &ddout, dout, dx);
functor(cuda_dev_ctx, &x, out, &ddx, &ddout, dout, dx);
platform::ForRange<platform::CUDADeviceContext> for_range(cuda_dev_ctx,
limit);
for_range(actual_functor);
} else {
#endif
auto &cpu_dev_ctx = dynamic_cast<platform::CPUDeviceContext &>(dev_ctx);
functor(cpu_dev_ctx, x, &out, &ddx, &ddout, dout, dx);
functor(cpu_dev_ctx, &x, out, &ddx, &ddout, dout, dx);
platform::ForRange<platform::CPUDeviceContext> for_range(cpu_dev_ctx,
limit);
for_range(actual_functor);
19 changes: 2 additions & 17 deletions python/paddle/fluid/layers/nn.py
@@ -9708,13 +9708,10 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
return out


@deprecated(since="2.0.0", update_to="paddle.nn.functional.leaky_relu")
@templatedoc()
def leaky_relu(x, alpha=0.02, name=None):
"""
:alias_main: paddle.nn.functional.leaky_relu
:alias: paddle.nn.functional.leaky_relu,paddle.nn.functional.activation.leaky_relu
:old_api: paddle.fluid.layers.leaky_relu

${comment}
Args:
x(${x_type}): ${x_comment}
@@ -9743,19 +9740,7 @@ def leaky_relu(x, alpha=0.02, name=None):
res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res])
print(res_val) # [[-0.1, 2], [3, -0.4]]
"""
if in_dygraph_mode():
return core.ops.leaky_relu(x, 'alpha', alpha)

check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'leaky_relu')

inputs = {'X': [x]}
attrs = {'alpha': alpha}
helper = LayerHelper('leaky_relu', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='leaky_relu', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
return paddle.nn.functional.leaky_relu(x, alpha, name)


def soft_relu(x, threshold=40.0, name=None):
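With fluid.layers.leaky_relu now deprecated in favour of the 2.0 API it forwards to, a typical call looks roughly like this. This is a sketch based on the dygraph test below; it assumes the 2.0-beta entry points paddle.nn.functional.leaky_relu and paddle.nn.LeakyReLU exercised there:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.to_variable(np.array([[-1., 2.], [3., -4.]], dtype='float32'))

out1 = F.leaky_relu(x, 0.01)   # functional form
m = paddle.nn.LeakyReLU(0.01)  # layer form
out2 = m(x)
print(out1.numpy())            # [[-0.01  2.  ]
                               #  [ 3.   -0.04]]
```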
96 changes: 84 additions & 12 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -737,37 +737,109 @@ def test_errors(self):
F.relu(x_fp16)


def ref_leaky_relu(x, alpha=0.01):
out = np.copy(x)
out[out < 0] *= alpha
return out


class TestLeakyRelu(TestActivation):
def get_alpha(self):
return 0.02

def setUp(self):
self.op_type = "leaky_relu"
self.init_dtype()
alpha = self.get_alpha()

np.random.seed(10)
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0.02 * x)
x[np.abs(x) < 0.005] = 0.05
out = ref_leaky_relu(x, alpha)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'alpha': alpha}

def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')


class TestLeakyReluOpError(unittest.TestCase):
class TestLeakyReluAlpha1(TestLeakyRelu):
def get_alpha(self):
return 2


class TestLeakyReluAlpha2(TestLeakyRelu):
def get_alpha(self):
return -0.01


class TestLeakyReluAlpha3(TestLeakyRelu):
def get_alpha(self):
return -2.0


class TestLeakyReluAPI(unittest.TestCase):
# test paddle.nn.LeakyReLU, paddle.nn.functional.leaky_relu,
# fluid.layers.leaky_relu
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12])
out1 = F.leaky_relu(x)
m = paddle.nn.LeakyReLU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_leaky_relu(self.x_np)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)

def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_variable(self.x_np)
out1 = F.leaky_relu(x)
m = paddle.nn.LeakyReLU()
out2 = m(x)
out_ref = ref_leaky_relu(self.x_np)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)

out1 = F.leaky_relu(x, 0.6)
m = paddle.nn.LeakyReLU(0.6)
out2 = m(x)
out_ref = ref_leaky_relu(self.x_np, 0.6)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()

def test_fluid_api(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data('X', [10, 12])
out = fluid.layers.leaky_relu(x, 0.01)
exe = fluid.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = ref_leaky_relu(self.x_np)
self.assertEqual(np.allclose(out_ref, res[0]), True)

def test_errors(self):
with program_guard(Program()):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.leaky_relu, 1)
self.assertRaises(TypeError, F.leaky_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.leaky_relu, x_int32)
# support the input dtype is float32
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float32')
fluid.layers.leaky_relu(x_fp16)
x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.leaky_relu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
F.leaky_relu(x_fp16)


def gelu(x, approximate):
22 changes: 0 additions & 22 deletions python/paddle/fluid/tests/unittests/test_layers.py
@@ -316,21 +316,6 @@ def test_relu(self):

self.assertTrue(np.allclose(static_ret, dy_ret_value))

def test_leakyrelu(self):
inputs = np.random.uniform(-1, 1, (10, 10)).astype('float32')
with self.static_graph():
t = layers.data(name='t', shape=[10, 10], dtype='float32')
ret = layers.leaky_relu(t, alpha=0.01)
static_ret = self.get_static_graph_result(
feed={'t': inputs}, fetch_list=[ret])[0]

with self.dynamic_graph():
lrelu = paddle.nn.LeakyReLU(alpha=0.01)
dy_ret = lrelu(base.to_variable(inputs))
dy_ret_value = dy_ret.numpy()

self.assertTrue(np.allclose(static_ret, dy_ret_value))

def test_pad2d(self):
with self.static_graph():
t = layers.data(name='t', shape=[-1, 3, 5, 5], dtype='float32')
@@ -2678,13 +2663,6 @@ def make_brelu(self):
out = layers.brelu(input, t_min=1.0, t_max=20.0, name='brelu')
return (out)

def make_leaky_relu(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = self._get_data(name="input", shape=[16], dtype="float32")
out = layers.leaky_relu(input, alpha=0.1, name='leaky_relu')
return (out)

def make_soft_relu(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):