From c84e870c553e4c815a32aeb4969029de17518976 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mat=C4=9Bj=20Ra=C4=8Dinsk=C3=BD?=
Date: Thu, 13 Aug 2020 17:13:13 +0200
Subject: [PATCH 1/6] fixes for https://github.com/FluxML/Zygote.jl/issues/758
 and tests for numerical stability

---
 src/lib/nnlib.jl  | 11 ++++++++++-
 test/gradcheck.jl | 13 +++++++++++--
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/src/lib/nnlib.jl b/src/lib/nnlib.jl
index edb81fb5e..32b217003 100644
--- a/src/lib/nnlib.jl
+++ b/src/lib/nnlib.jl
@@ -10,13 +10,22 @@ end
 function dselu(x)
   λ = oftype(x/1, 1.0507009873554804934193349852946)
   α = oftype(x/1, 1.6732632423543772848170429916717)
-  λ * ifelse(x > 0, 1, α * exp(x))
+  λ * ifelse(x > 0, one(x), α * exp(x))
 end
+@adjoint selu(x) = selu(x), Δ -> (dselu(x) * Δ,)
 
 @adjoint function Base.Broadcast.broadcasted(::typeof(selu), x::Numeric)
   selu.(x), Δ -> (nothing, dselu.(x) .* Δ)
 end
 
+delu(x, α) = ifelse(x ≥ 0, one(x), α * exp(x))
+
+@adjoint elu(x, α) = softmax(xs, dims=dims), Δ -> (∇softmax(Δ, xs, dims=dims),)
+@adjoint elu(x, α) = elu(x, α), Δ -> (delu.(x, α) .* Δ, nothing)
+@adjoint function Base.Broadcast.broadcasted(::typeof(elu), x::Numeric, α::Numeric)
+  elu.(x, α), Δ -> (nothing, delu.(x, α) .* Δ, nothing)
+end
+
 @adjoint function σ(x::Real)
   y = σ(x)
   return y, Δ -> (Δ * y * (1 - y),)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index e120293d7..3eee44824 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -93,6 +93,15 @@ end
 @test gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
 @test gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
+@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
+@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)
+
+# tests for https://github.com/FluxML/Zygote.jl/issues/758
+@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000]) == ([1.0507009873554805, 1.0507009873554805],)
+@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
+@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
+@test gradient(x -> elu(x, 2), 1_000) == (1.,)
+@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
 
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)
@@ -1453,7 +1462,7 @@ end
 
   x = randn(Float64,16,16)
   @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
-  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
+  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
 
   x = randn(Float32,16)
   P = plan_fft(x)
@@ -1462,7 +1471,7 @@ end
   @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}
 
   x = randn(Float32,16,16)
-  @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
+  @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
   @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
 end

From 646b32990fb4b5ac2b4b93ae47e4aa0ea7ff6176 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mat=C4=9Bj=20Ra=C4=8Dinsk=C3=BD?=
Date: Thu, 13 Aug 2020 23:14:46 +0200
Subject: [PATCH 2/6] fixing some copy-pasted code

---
 src/lib/nnlib.jl  | 5 ++---
 test/gradcheck.jl | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/lib/nnlib.jl b/src/lib/nnlib.jl
index 32b217003..c8039ed94 100644
--- a/src/lib/nnlib.jl
+++ b/src/lib/nnlib.jl
@@ -13,15 +13,14 @@ function dselu(x)
   λ * ifelse(x > 0, one(x), α * exp(x))
 end
 
-@adjoint selu(x) = selu(x), Δ -> (dselu(x) * Δ,)
+@adjoint selu(x::Numeric) = selu(x), Δ -> (dselu(x) * Δ,)
 @adjoint function Base.Broadcast.broadcasted(::typeof(selu), x::Numeric)
   selu.(x), Δ -> (nothing, dselu.(x) .* Δ)
 end
 
 delu(x, α) = ifelse(x ≥ 0, one(x), α * exp(x))
 
-@adjoint elu(x, α) = softmax(xs, dims=dims), Δ -> (∇softmax(Δ, xs, dims=dims),)
-@adjoint elu(x, α) = elu(x, α), Δ -> (delu.(x, α) .* Δ, nothing)
+@adjoint elu(x::Numeric, α::Numeric) = elu(x, α), Δ -> (delu.(x, α) .* Δ, nothing)
 @adjoint function Base.Broadcast.broadcasted(::typeof(elu), x::Numeric, α::Numeric)
   elu.(x, α), Δ -> (nothing, delu.(x, α) .* Δ, nothing)
 end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 3eee44824..eb7893433 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1462,7 +1462,7 @@ end
 
   x = randn(Float64,16,16)
   @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
-  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
+  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
 
   x = randn(Float32,16)
   P = plan_fft(x)

From 215a5f0ea6a3ca2b2847c9947df2424a675e6a73 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mat=C4=9Bj=20Ra=C4=8Dinsk=C3=BD?=
Date: Fri, 14 Aug 2020 09:32:18 +0200
Subject: [PATCH 3/6] added gradcheck to test against numerical
 differentiation

---
 test/gradcheck.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index eb7893433..bcaab30b5 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -102,6 +102,8 @@ end
 @test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
 @test gradient(x -> elu(x, 2), 1_000) == (1.,)
 @test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
+@test gradcheck(x->sum(selu.(x)),[1_000, 10_000])
+@test gradcheck(x->sum(elu.(x, 3.5)),[1_000, 10_000])
 
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)
@@ -1462,7 +1464,7 @@ end
 
   x = randn(Float64,16,16)
   @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
-  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
+  @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
 
   x = randn(Float32,16)
   P = plan_fft(x)

From aba8c5b85bdd624da00674ba4def4dd266f77bcd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mat=C4=9Bj=20Ra=C4=8Dinsk=C3=BD?=
Date: Fri, 14 Aug 2020 09:37:07 +0200
Subject: [PATCH 4/6] using floats for gradcheck

---
 test/gradcheck.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index bcaab30b5..b01a4d531 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -102,8 +102,8 @@ end
 @test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
 @test gradient(x -> elu(x, 2), 1_000) == (1.,)
 @test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
-@test gradcheck(x->sum(selu.(x)),[1_000, 10_000])
-@test gradcheck(x->sum(elu.(x, 3.5)),[1_000, 10_000])
+@test gradcheck(x->sum(selu.(x)),[1_000., 10_000.])
+@test gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.])
 
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)
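
A note on PATCH 3 and PATCH 4: gradcheck validates Zygote's reverse-mode gradient against ngradient, a central-difference estimate that perturbs each input entry in place by a fractional step, which is why the integer arrays from PATCH 3 had to become float literals. A minimal sketch of such a checker follows (illustrative only; the name numeric_gradient and the chosen step size are assumptions, not the exact helpers in test/gradcheck.jl):

# Sketch of a central-difference gradient check in the spirit of the
# ngradient/gradcheck helpers in test/gradcheck.jl; names and step size are
# illustrative assumptions, not the test suite's exact implementation.
function numeric_gradient(f, x::AbstractArray{<:AbstractFloat})
    g = zero(x)
    δ = sqrt(eps(eltype(x)))  # fractional step: this is why integer inputs
    for i in eachindex(x)     # had to become float literals in PATCH 4
        tmp = x[i]
        x[i] = tmp + δ / 2
        y2 = f(x)
        x[i] = tmp - δ / 2
        y1 = f(x)
        x[i] = tmp            # restore the entry before moving on
        g[i] = (y2 - y1) / δ  # central difference, O(δ²) truncation error
    end
    return g
end

Comparing such an estimate against gradient(f, x)[1] with a loose tolerance is essentially what the gradcheck calls in these patches do.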
From 33bbc021e9fd44988871d99145f242ed6a40e33d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mat=C4=9Bj=20Ra=C4=8Dinsk=C3=BD?=
Date: Fri, 14 Aug 2020 10:01:50 +0200
Subject: [PATCH 5/6] added tests and a comment on numerical instability

---
 test/gradcheck.jl | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index b01a4d531..cc2e6f55a 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -102,8 +102,15 @@ end
 @test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
 @test gradient(x -> elu(x, 2), 1_000) == (1.,)
 @test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
-@test gradcheck(x->sum(selu.(x)),[1_000., 10_000.])
-@test gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.])
+@test gradcheck(x->sum(selu.(x)),[100., 1_000.])
+@test gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.])
+# numerical instability in the finite-difference check, even in the linear region of these functions; see:
+# julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.])
+# ([1.0506591796875, 1.0506591796875],)
+# julia> gradient(x->sum(selu.(x)),[1_000., 10_000.])
+# ([1.0507009873554805, 1.0507009873554805],)
+@test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.])
+@test_broken gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.])
 
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)

From 1b27b9c17dbffbc531600316f1b4cd798fe4cd30 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mat=C4=9Bj=20Ra=C4=8Dinsk=C3=BD?=
Date: Fri, 14 Aug 2020 11:01:57 +0200
Subject: [PATCH 6/6] passing big-numbers test for elu

---
 test/gradcheck.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index cc2e6f55a..2a7415d26 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -104,13 +104,13 @@ end
 @test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
 @test gradcheck(x->sum(selu.(x)),[100., 1_000.])
 @test gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.])
+@test gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.]) # interestingly, this passes for elu but not for selu
 # numerical instability in the finite-difference check, even in the linear region of these functions; see:
 # julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.])
 # ([1.0506591796875, 1.0506591796875],)
 # julia> gradient(x->sum(selu.(x)),[1_000., 10_000.])
 # ([1.0507009873554805, 1.0507009873554805],)
 @test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.])
-@test_broken gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.])
 
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)
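
Why the selu check at [1_000., 10_000.] stays @test_broken is down to Float64 spacing in the finite-difference estimate, not to the new adjoints. A back-of-the-envelope sketch (assumes a Zygote build containing the adjoints from this series; λ is the slope asserted by the tests above):

# The central difference divides a tiny change in a large function value,
# so the estimate is quantized by the floating-point spacing around f(xs).
using Zygote, NNlib  # assumption: Zygote with the adjoints from this series

λ = 1.0507009873554805           # selu slope for x > 0, as asserted above

f = xs -> sum(selu.(xs))
xs = [1_000.0, 10_000.0]
fx = f(xs)                       # ≈ λ * 11_000 ≈ 1.16e4
δ = sqrt(eps())                  # ≈ 1.5e-8, a typical finite-difference step

# eps(fx) ≈ 1.8e-12 is the representable spacing around f(xs), while the probe
# only moves f by about λ * δ ≈ 1.6e-8, so the quotient bounds the relative
# error of the finite-difference estimate:
rel_err = eps(fx) / (λ * δ)      # ≈ 1.2e-4, roughly the order of the gap
                                 # between 1.0506591796875 and λ quoted above

# The adjoints themselves are exact in this regime:
@assert gradient(x -> selu(x), 1_000.0) == (λ,)
@assert gradient(x -> elu(x, 2), 1_000.0) == (1.0,)
@assert gradient(x -> elu(x, 2), -1.0) == (2 * exp(-1),)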