Skip to content

Commit

Permalink
Merge #765
Browse files Browse the repository at this point in the history
765: fix for #758 r=DhairyaLGandhi a=racinmat

Added few tests, defined adjoint for scalar selu and elu.
Fixes #578 
I hope this is right, I'm still not sure if I fully got all the tricks around adjoints.

Co-authored-by: Matěj Račinský <matej.racinsky@avast.com>
  • Loading branch information
bors[bot] and racinmat committed Aug 18, 2020
2 parents 956575e + 1b27b9c commit 09ba74e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
10 changes: 9 additions & 1 deletion src/lib/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,21 @@ end
# Derivative of the SELU activation.
# selu(x) = λ*x for x > 0 and λ*α*(exp(x) - 1) otherwise, so
# d/dx selu(x) = λ for x > 0 and λ*α*exp(x) for x ≤ 0.
# (The scrape of this diff retained both the pre- and post-change line of the
# body; this is the corrected post-commit version.)
function dselu(x)
    # oftype(x/1, c) converts the constants to the float type of x, keeping the
    # function type-generic (works for Float32, Float64, duals, ...).
    λ = oftype(x/1, 1.0507009873554804934193349852946)
    α = oftype(x/1, 1.6732632423543772848170429916717)
    # one(x) instead of literal 1 keeps both ifelse branches the same type,
    # avoiding Int/Float promotion problems for large x (FluxML/Zygote.jl#758).
    λ * ifelse(x > 0, one(x), α * exp(x))
end

# Custom adjoints for selu using the analytic derivative `dselu` rather than
# differentiating through the NNlib implementation; this avoids the numerical
# problems for large inputs reported in FluxML/Zygote.jl#758.
@adjoint selu(x::Numeric) = selu(x), Δ -> (dselu(x) * Δ,)
# Broadcast rule: the leading `nothing` is the (non-existent) gradient with
# respect to the broadcasted function itself.
@adjoint function Base.Broadcast.broadcasted(::typeof(selu), x::Numeric)
selu.(x), Δ -> (nothing, dselu.(x) .* Δ)
end

# Derivative of the ELU activation with respect to x.
# elu(x, α) = x for x > 0 and α*(exp(x) - 1) otherwise, so
# d/dx elu(x, α) = 1 for x > 0 and α*exp(x) for x ≤ 0.
# (The scrape dropped the `>` operator; restored here. `one(x)` keeps both
# branches type-stable for large inputs, matching `dselu` above.)
delu(x, α) = ifelse(x > 0, one(x), α * exp(x))

# Custom adjoints for elu using the analytic derivative `delu`.
# The trailing `nothing` means no gradient is propagated to the α parameter;
# in the broadcast rule the leading `nothing` is for the function itself.
@adjoint elu(x::Numeric, α::Numeric) = elu(x, α), Δ -> (delu.(x, α) .* Δ, nothing)
@adjoint function Base.Broadcast.broadcasted(::typeof(elu), x::Numeric, α::Numeric)
elu.(x, α), Δ -> (nothing, delu.(x, α) .* Δ, nothing)
end

@adjoint function σ(x::Real)
y = σ(x)
return y, Δ -> (Δ * y * (1 - y),)
Expand Down
22 changes: 20 additions & 2 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ end
# Gradient checks for activation functions composed with an affine map,
# for both vector and matrix inputs. NOTE(review): these lines are the interior
# of a @testset whose header is outside this view.
@test gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)

# tests for https://github.com/FluxML/Zygote.jl/issues/758
# Large inputs (1e3, 1e4) previously overflowed; with the analytic adjoints the
# gradient of the linear region is exactly λ (selu) or 1 (elu).
@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000]) == ([1.0507009873554805, 1.0507009873554805],)
@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
@test gradient(x -> elu(x, 2), 1_000) == (1.,)
@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
@test gradcheck(x->sum(selu.(x)),[100., 1_000.])
@test gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.])
@test gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.]) # for elu the tests are passing but for selu not, interesting
# numerical instability even for the linear part of such function, see:
# julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0506591796875, 1.0506591796875],)
# julia> gradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0507009873554805, 1.0507009873554805],)
@test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.])

@test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)
Expand Down Expand Up @@ -1453,7 +1471,7 @@ end

x = randn(Float64,16,16)
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}

x = randn(Float32,16)
P = plan_fft(x)
Expand All @@ -1462,7 +1480,7 @@ end
@test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}

x = randn(Float32,16,16)
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
end

Expand Down

0 comments on commit 09ba74e

Please sign in to comment.