use FiniteDifferences.jl for gradient checks #464

Open
wants to merge 18 commits into base: master
Changes from 9 commits
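At its core, the change replaces the hand-rolled central-difference loop in ngradient with FiniteDifferences.grad, and keeps gradcheck as a comparison of those estimates against Zygote's gradients. A minimal standalone sketch of that comparison (check_gradient is an illustrative name, not part of the PR):

```julia
using Zygote, FiniteDifferences

# Compare Zygote's reverse-mode gradient against a 5-point central
# finite-difference estimate, mirroring what the PR's gradcheck does.
function check_gradient(f, xs...)
    fdm = FiniteDifferences.central_fdm(5, 1)
    fin_grads = FiniteDifferences.grad(fdm, f, xs...)  # tuple, one gradient per argument
    ad_grads = Zygote.gradient(f, xs...)
    return all(isapprox.(fin_grads, ad_grads, rtol=1e-5, atol=1e-5))
end

check_gradient(x -> sum(sin.(x)), rand(3, 3))  # expected to return true
```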
4 changes: 4 additions & 0 deletions .travis.yml
@@ -25,6 +25,10 @@ jobs:
- julia --color=yes --project=docs/ docs/make.jl
after_success: skip

script:
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
- travis_wait 30 julia -e 'using Pkg; Pkg.build(); Pkg.test(coverage=true)'

## uncomment and modify the following lines to manually install system packages
#addons:
# apt: # apt-get for linux
28 changes: 12 additions & 16 deletions Manifest.toml
@@ -23,9 +23,9 @@ version = "0.2.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "f784254f428fb8fd7ac15982e5862a38a44523d3"
git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.7"
version = "0.17.9"

[[Dates]]
deps = ["Printf"]
@@ -61,16 +61,15 @@ version = "3.3.9+3"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "de38b0253ade98340fabaf220f368f6144541938"
pinned = true
git-tree-sha1 = "fec413d4fc547992eb62a5c544cedb6d7853c1f5"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.7.4"
version = "0.8.4"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "840700059391d36e2498d89c2e82c08f261f2a2a"
git-tree-sha1 = "88b082d492be6b63f967b6c96b352e25ced1a34c"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.8"
version = "0.10.9"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
@@ -89,7 +88,6 @@ deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[LibGit2]]
deps = ["Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
@@ -103,10 +101,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MKL_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "61069ae718b8ab1e325bbfb4e5268902e7ea08e3"
deps = ["IntelOpenMP_jll", "Libdl", "Pkg"]
git-tree-sha1 = "720629cc8cbd12c146ca01b661fd1a6cf66e2ff4"
uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
version = "2019.0.117+0"
version = "2019.0.117+2"

[[MacroTools]]
deps = ["DataStructures", "Markdown", "Random"]
@@ -120,11 +118,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[NNlib]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
git-tree-sha1 = "5642639793b2de824519683c336daa6bda84ff05"
repo-rev = "master"
repo-url = "https://github.com/FluxML/NNlib.jl.git"
git-tree-sha1 = "755c0bab3912ff782167e1b4b774b833f8a0e550"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.2"
version = "0.6.4"

[[NaNMath]]
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
@@ -144,7 +140,7 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
4 changes: 3 additions & 1 deletion Project.toml
@@ -23,6 +23,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
DiffRules = "0.0, 0.1, 1"
FFTW = "1"
FillArrays = "0"
FiniteDifferences = "0.9"
ForwardDiff = "0"
IRTools = "0.3"
MacroTools = "0.5"
@@ -37,8 +38,9 @@ julia = "1"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Distances", "StatsFuns", "CUDAapi", "CuArrays"]
test = ["CUDAapi", "CuArrays", "Distances", "FiniteDifferences", "StatsFuns", "Test"]
73 changes: 46 additions & 27 deletions test/gradcheck.jl
@@ -2,28 +2,52 @@ using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays, FFTW
using Zygote: gradient
using NNlib: conv, ∇conv_data, depthwiseconv
using Base.Broadcast: broadcast_shape
using FiniteDifferences: FiniteDifferences, central_fdm, forward_fdm, backward_fdm

function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x)
δ = sqrt(eps())
tmp = x[i]
x[i] = tmp - δ/2
y1 = f(xs...)
x[i] = tmp + δ/2
y2 = f(xs...)
x[i] = tmp
Δ[i] = (y2-y1)/δ
realdomainrange(::Any) = (-Inf, Inf)
realdomainrange(::Union{typeof.((acos,asin,atanh))...}) = (-1, 1)
realdomainrange(::typeof(acosh)) = (1, Inf)
realdomainrange(::Union{typeof.((log,sqrt,^))...}) = (0, Inf)

function default_fdm(f::F) where F
# Attempt to choose a way of finite differencing that will avoid escaping the domain
lower, upper = realdomainrange(F)
Member:
Using realdomainrange here makes me a bit uneasy... if you're implementing an adjoint it seems like it'd be hard to know that you need to do this (if you do?) and how to do it right. Not any kind of deal-breaker, but maybe there's some way to avoid it.

@oxinabox (Member, Author), Jan 23, 2020:
You don't need to do it. realdomainrange falls back to (-Inf, Inf), and in basically every test bar a few critical ones that fallback is already being hit, because the tests use anonymous functions.

I actually just moved this code up from below, where it was already defined for those few critical ones.
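A small sketch of the fallback behaviour described above, assuming the realdomainrange methods added in this diff are in scope (the values in the comments follow from those definitions):

```julia
# Named functions with a restricted real domain hit their specific method;
# an anonymous wrapper around the same function falls through to ::Any.
realdomainrange(log)          # (0, Inf)
realdomainrange(acosh)        # (1, Inf)
realdomainrange(x -> log(x))  # (-Inf, Inf), the fallback most tests end up using
```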

if lower == -Inf && upper==Inf # Ideal case
central_fdm(5, 1)
elseif upper == Inf
forward_fdm(5, 1)
elseif lower == -Inf
backward_fdm(5,1)
else # fallback, hopefully input is not near bounds
central_fdm(5, 1)
end
end

function ngradient(f, xs::AbstractArray...; fdm=default_fdm(f))
println(" ") # make sure TravisCI doesn't time out
return FiniteDifferences.grad(fdm, f, xs...)
end

function gradcheck(f, xs...; kwargs...)
fin_grads = ngradient(f, xs...; kwargs...)
ad_grads = gradient(f, xs...)
all_correct = true
for (ii, (fin_grad, ad_grad)) in enumerate(zip(fin_grads, ad_grads))
correct = isapprox(fin_grad, ad_grad, rtol = 1e-5, atol = 1e-5)
if !correct
all_correct = false
# need to stringify arrays so they show content, rather than just type and size
@debug "gradcheck failed" f nth_partial=ii fin_grad="$fin_grad" ad_grad="$ad_grad"
end
end
return grads
return all_correct
end

gradcheck(f, xs...) =
all(isapprox.(ngradient(f, xs...),
gradient(f, xs...), rtol = 1e-5, atol = 1e-5))

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
gradtest(f, xs::AbstractArray...; kwargs...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kwargs...)
# We generate random matrix with elements between 0.2 and -.7 so we are not close to any
# nondefined areas for common functions
gradtest(f, dims...; kwargs...) = gradtest(f, rand.(Float64, dims)...; kwargs...)

# utilities for using gradcheck with complex matrices
_splitreim(A) = (real(A),)
@@ -77,7 +101,7 @@ Random.seed!(0)
@test gradtest(x -> x', rand(5))

@test gradtest(det, (4, 4))
@test gradtest(logdet, map(x -> x*x', (rand(4, 4),))[1])
@test gradtest(logdet, map(x -> x*x' + I, (rand(4, 4),))[1])
@test gradtest(x -> logabsdet(x)[1], (4, 4))

@testset "getindex" begin
@@ -153,7 +177,7 @@ end

@testset "circshift" begin
L = 5
for D ∈ 1:5, reps ∈ 1:5
for D ∈ 1:5, reps ∈ 1:5
x0 = zeros(ntuple(d->L, D))
g = gradient(x -> x[1], x0)[1] #Zero shift gradient
shift = ntuple(_ -> rand(-L:L), D) #Random shift
@@ -250,7 +274,7 @@ end
end

@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(maximum, rand(2, 4))

@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
@@ -625,8 +649,8 @@ end
_hermsymtype(::Type{<:Symmetric}) = Symmetric
_hermsymtype(::Type{<:Hermitian}) = Hermitian

function _gradtest_hermsym(f, ST, A)
gradtest(_splitreim(collect(A))...) do (args...)
function _gradtest_hermsym(f, ST, A; kwargs...)
gradtest(_splitreim(collect(A))...; kwargs...) do (args...)
B = f(ST(_joinreim(_dropimaggrad.(args)...)))
return sum(_splitreim(B))
end
@@ -684,11 +708,6 @@ function _randvectorin(rng, n, r)
return rand(rng, n) .* (u - l) .+ l
end

realdomainrange(::Any) = (Inf, Inf)
realdomainrange(::Union{typeof.((acos,asin,atanh))...}) = (-1, 1)
realdomainrange(::typeof(acosh)) = (1, Inf)
realdomainrange(::Union{typeof.((log,sqrt,^))...}) = (0, Inf)

function _randmatseries(rng, f, T, n, domain::Type{Real})
U = _randmatunitary(rng, T, n)
λ = _randvectorin(rng, n, realdomainrange(f))