diff --git a/src/lib/nnlib.jl b/src/lib/nnlib.jl
index edb81fb5e..c8039ed94 100644
--- a/src/lib/nnlib.jl
+++ b/src/lib/nnlib.jl
@@ -10,13 +10,21 @@ end
 function dselu(x)
   λ = oftype(x/1, 1.0507009873554804934193349852946)
   α = oftype(x/1, 1.6732632423543772848170429916717)
-  λ * ifelse(x > 0, 1, α * exp(x))
+  λ * ifelse(x > 0, one(x), α * exp(x))
 end

+@adjoint selu(x::Numeric) = selu(x), Δ -> (dselu(x) * Δ,)
 @adjoint function Base.Broadcast.broadcasted(::typeof(selu), x::Numeric)
   selu.(x), Δ -> (nothing, dselu.(x) .* Δ)
 end

+delu(x, α) = ifelse(x ≥ 0, one(x), α * exp(x))
+
+@adjoint elu(x::Numeric, α::Numeric) = elu(x, α), Δ -> (delu.(x, α) .* Δ, nothing)
+@adjoint function Base.Broadcast.broadcasted(::typeof(elu), x::Numeric, α::Numeric)
+  elu.(x, α), Δ -> (nothing, delu.(x, α) .* Δ, nothing)
+end
+
 @adjoint function σ(x::Real)
   y = σ(x)
   return y, Δ -> (Δ * y * (1 - y),)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index e120293d7..2a7415d26 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -93,6 +93,24 @@ end
 @test gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
 @test gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
+@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
+@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)
+
+# tests for https://github.com/FluxML/Zygote.jl/issues/758
+@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000]) == ([1.0507009873554805, 1.0507009873554805],)
+@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
+@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
+@test gradient(x -> elu(x, 2), 1_000) == (1.,)
+@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
+@test gradcheck(x -> sum(selu.(x)), [100., 1_000.])
+@test gradcheck(x -> sum(elu.(x, 3.5)), [100., 1_000.])
+@test gradcheck(x -> sum(elu.(x, 3.5)), [1_000., 10_000.]) # elu passes here, but the matching selu check below does not
+# Finite differencing is numerically unstable even on the linear branch of selu; compare:
+# julia> ngradient(x -> sum(selu.(x)), [1_000., 10_000.])
+# ([1.0506591796875, 1.0506591796875],)
+# julia> gradient(x -> sum(selu.(x)), [1_000., 10_000.])
+# ([1.0507009873554805, 1.0507009873554805],)
+@test_broken gradcheck(x -> sum(selu.(x)), [1_000., 10_000.])

 @test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)
@@ -1453,7 +1471,7 @@ end
   x = randn(Float64,16,16)
   @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
-  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
+  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}

   x = randn(Float32,16)
   P = plan_fft(x)
@@ -1462,7 +1480,7 @@ end
   @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}

   x = randn(Float32,16,16)
-  @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
+  @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
   @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
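
A note on the `one(x)` change in `dselu`: with a literal `1`, the two branches of the `ifelse` have different types (`Int` versus the float type of `α * exp(x)`), so the intermediate is type-unstable and `ifelse` can hand back an `Int`. A minimal illustration of the difference, assuming a `Float32` input (this snippet is illustrative, not code from the patch):

    x = 0.5f0
    α = oftype(x/1, 1.6732632423543772848170429916717)

    typeof(ifelse(x > 0, 1, α * exp(x)))       # Int64: the branches disagree
    typeof(ifelse(x > 0, one(x), α * exp(x)))  # Float32: both branches match x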
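
Why a scalar `@adjoint` is needed in addition to the broadcast rule: a plain `selu(x)` call never reaches `Base.Broadcast.broadcasted`, so without the new rule Zygote differentiates through NNlib's definition, and the zero cotangent routed to the inactive `exp` branch gets multiplied by that branch's derivative; once `exp(x)` overflows to `Inf`, the product is `NaN`, which is plausibly the failure behind issue #758. A sketch of the underlying IEEE hazard, independent of any AD machinery (the `dselu` below is a simplified copy of the patch's version; everything else is illustrative):

    λ = 1.0507009873554804934193349852946
    α = 1.6732632423543772848170429916717

    x = 1_000.0
    exp(x)        # Inf: the x ≤ 0 branch overflows even though it is unused
    0.0 * exp(x)  # NaN: a zero cotangent times Inf poisons the gradient

    # The hand-written rule avoids the product entirely: it evaluates both
    # branches but only *selects* one, so an Inf that is never multiplied in
    # is simply discarded.
    dselu(x) = λ * ifelse(x > 0, one(x), α * exp(x))
    dselu(1_000.0)  # 1.0507009873554805, the value the new tests expect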
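
On the `@test_broken`: the analytic gradient is exact on the linear branch, but `gradcheck` compares it against central finite differences, and those lose digits once `f(x)` is large, because the spacing of representable values around `f(x)` is not much smaller than the difference being measured. A back-of-the-envelope reproduction, assuming a step of `√eps()` (the step Zygote's `ngradient` test helper uses, as far as I can tell):

    λ = 1.0507009873554804934193349852946
    selu_lin(x) = λ * x            # selu on its x > 0 branch

    δ = sqrt(eps(Float64))         # ≈ 1.49e-8
    x = 10_000.0
    fd = (selu_lin(x + δ/2) - selu_lin(x - δ/2)) / δ

    # f(x) ≈ 1.05e4, where neighbouring Float64 values sit eps(1.05e4) ≈ 1.8e-12
    # apart, while the true difference is only λ*δ ≈ 1.6e-8. Roughly four
    # significant digits survive the cancellation, so fd matches λ only to
    # about 1e-4, too loose for gradcheck's tolerance -- hence @test_broken.
    fd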