Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use FiniteDifferences.jl for gradient checks #464

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 1 addition & 3 deletions .travis.yml
Expand Up @@ -15,9 +15,7 @@ git:
matrix:
allow_failures:
- julia: nightly

jobs:
include:
include:
- stage: "Documentation"
julia: 1.3
os: linux
Expand Down
28 changes: 12 additions & 16 deletions Manifest.toml
Expand Up @@ -23,9 +23,9 @@ version = "0.2.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "f784254f428fb8fd7ac15982e5862a38a44523d3"
git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.7"
version = "0.17.9"

[[Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -61,16 +61,15 @@ version = "3.3.9+3"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "de38b0253ade98340fabaf220f368f6144541938"
pinned = true
git-tree-sha1 = "fec413d4fc547992eb62a5c544cedb6d7853c1f5"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.7.4"
version = "0.8.4"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "840700059391d36e2498d89c2e82c08f261f2a2a"
git-tree-sha1 = "88b082d492be6b63f967b6c96b352e25ced1a34c"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.8"
version = "0.10.9"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
Expand All @@ -89,7 +88,6 @@ deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[LibGit2]]
deps = ["Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
Expand All @@ -103,10 +101,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MKL_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "61069ae718b8ab1e325bbfb4e5268902e7ea08e3"
deps = ["IntelOpenMP_jll", "Libdl", "Pkg"]
git-tree-sha1 = "720629cc8cbd12c146ca01b661fd1a6cf66e2ff4"
uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
version = "2019.0.117+0"
version = "2019.0.117+2"

[[MacroTools]]
deps = ["DataStructures", "Markdown", "Random"]
Expand All @@ -120,11 +118,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[NNlib]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
git-tree-sha1 = "5642639793b2de824519683c336daa6bda84ff05"
repo-rev = "master"
repo-url = "https://github.com/FluxML/NNlib.jl.git"
git-tree-sha1 = "755c0bab3912ff782167e1b4b774b833f8a0e550"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.2"
version = "0.6.4"

[[NaNMath]]
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
Expand All @@ -144,7 +140,7 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Expand Up @@ -23,6 +23,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
DiffRules = "0.0, 0.1, 1"
FFTW = "1"
FillArrays = "0"
FiniteDifferences = "0.9"
ForwardDiff = "0"
IRTools = "0.3"
MacroTools = "0.5"
Expand All @@ -37,8 +38,9 @@ julia = "1"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Distances", "StatsFuns", "CUDAapi", "CuArrays"]
test = ["CUDAapi", "CuArrays", "Distances", "FiniteDifferences", "StatsFuns", "Test"]
98 changes: 68 additions & 30 deletions test/gradcheck.jl
Expand Up @@ -2,28 +2,51 @@ using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays, FFTW
using Zygote: gradient
using NNlib: conv, ∇conv_data, depthwiseconv
using Base.Broadcast: broadcast_shape
using FiniteDifferences: FiniteDifferences, central_fdm, forward_fdm, backward_fdm

function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x)
δ = sqrt(eps())
tmp = x[i]
x[i] = tmp - δ/2
y1 = f(xs...)
x[i] = tmp + δ/2
y2 = f(xs...)
x[i] = tmp
Δ[i] = (y2-y1)/δ
"""
    realdomainrange(f)

Return `(lower, upper)`: the open interval of the real line on which `f` is
real-valued. Unknown functions fall back to the whole real line, `(-Inf, Inf)`.
"""
realdomainrange(::Any) = (-Inf, Inf)
# Inverse trig/hyperbolic functions that are only real on (-1, 1).
realdomainrange(::Union{typeof(acos), typeof(asin), typeof(atanh)}) = (-1, 1)
# acosh is only real for arguments of at least 1.
realdomainrange(::typeof(acosh)) = (1, Inf)
# Functions that are only real for non-negative arguments.
realdomainrange(::Union{typeof(log), typeof(sqrt), typeof(^)}) = (0, Inf)

"""
    default_fdm(f)

Choose a finite-differencing method for `f` that attempts to avoid stepping
outside the real domain of `f` (as reported by `realdomainrange`). A one-sided
method is used when the domain is bounded on one side only; otherwise a
central method is used.
"""
function default_fdm(f::F) where F
    # Attempt to choose a way of finite differencing that will avoid escaping the domain.
    # NOTE: dispatch on the *instance* `f`, not the type `F` — the non-fallback
    # `realdomainrange` methods are defined for instances (e.g. `::typeof(acos)`),
    # so passing the type would always hit the `(-Inf, Inf)` fallback.
    lower, upper = realdomainrange(f)
    if lower == -Inf && upper == Inf
        # Ideal case: differencing cannot escape the domain.
        central_fdm(3, 1; adapt=0)
    elseif upper == Inf
        # Bounded below only: step forwards, away from the lower bound.
        forward_fdm(3, 1; adapt=0)
    elseif lower == -Inf
        # Bounded above only: step backwards, away from the upper bound.
        backward_fdm(3, 1; adapt=0)
    else
        # Bounded on both sides: fall back to central differencing and hope
        # the input is not near either bound.
        central_fdm(3, 1; adapt=0)
    end
end

"""
    ngradient(f, xs::AbstractArray...; fdm=default_fdm(f))

Numerically estimate the gradient of `f` at `xs` using the finite-differencing
method `fdm` (chosen by `default_fdm` unless overridden).
"""
ngradient(f, xs::AbstractArray...; fdm=default_fdm(f)) =
    FiniteDifferences.grad(fdm, f, xs...)

"""
    gradcheck(f, xs...; kwargs...)

Compare Zygote's reverse-mode gradient of `f` at `xs` against a
finite-difference estimate from `ngradient`. Returns `true` iff every partial
agrees to within `rtol = atol = 1e-5`; mismatches are reported via `@debug`.
Keyword arguments (e.g. `fdm`) are forwarded to `ngradient`.
"""
function gradcheck(f, xs...; kwargs...)
    fin_grads = ngradient(f, xs...; kwargs...)
    ad_grads = gradient(f, xs...)
    all_correct = true
    for (ii, (fin_grad, ad_grad)) in enumerate(zip(fin_grads, ad_grads))
        correct = isapprox(fin_grad, ad_grad, rtol = 1e-5, atol = 1e-5)
        if !correct
            all_correct = false
            # need to stringify arrays so they show content, rather than just type and size
            @debug "gradcheck failed" f nth_partial=ii fin_grad="$fin_grad" ad_grad="$ad_grad"
        end
    end
    return all_correct
end

gradcheck(f, xs...) =
all(isapprox.(ngradient(f, xs...),
gradient(f, xs...), rtol = 1e-5, atol = 1e-5))

# Gradient-check a scalar-valued wrapping `sum(sin, f(xs...))` of `f`, so that
# functions with array outputs can be checked; kwargs are forwarded to `gradcheck`.
gradtest(f, xs::AbstractArray...; kwargs...) = gradcheck((xs...) -> sum(sin, f(xs...)), xs...; kwargs...)
# We generate random matrix with elements between 0.2 and -.7 so we are not close to any
# nondefined areas for common functions
# NOTE(review): the comment above claims elements in (-0.7, 0.2), but
# `rand.(Float64, dims)` draws from [0, 1) — confirm which is intended.
gradtest(f, dims...; kwargs...) = gradtest(f, rand.(Float64, dims)...; kwargs...)

# utilities for using gradcheck with complex matrices
_splitreim(A) = (real(A),)
Expand Down Expand Up @@ -77,7 +100,7 @@ Random.seed!(0)
@test gradtest(x -> x', rand(5))

@test gradtest(det, (4, 4))
@test gradtest(logdet, map(x -> x*x', (rand(4, 4),))[1])
@test gradtest(logdet, map(x -> x*x' + I, (rand(4, 4),))[1])
@test gradtest(x -> logabsdet(x)[1], (4, 4))

@testset "getindex" begin
Expand Down Expand Up @@ -111,7 +134,9 @@ end
@test gradient(g, ones(3)) == ([1,0,0],)
end

@info "Still GradChecking (next is conv)"
@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
@info "Still GradChecking (conv: spatial_rank=$spatial_rank)"
x = rand(repeat([10], spatial_rank)..., 3, 2)
w = rand(repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
Expand All @@ -122,6 +147,8 @@ end
@test gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)
end


@info "Still GradChecking (next is pooling)"
@testset "pooling: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)
x = rand(repeat([10], spatial_rank)..., 3, 2)
pdims = PoolDims(x, 2)
Expand All @@ -138,9 +165,13 @@ let
@test first(back(randn(1, 3))) isa Vector
end

@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))

@info "Still GradChecking (next is repeat)"
@testset "repeat" begin
@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
end

@test gradtest(tr, rand(4, 4))

Expand All @@ -153,7 +184,7 @@ end

@testset "circshift" begin
L = 5
for D ∈ 1:5, reps ∈ 1:5
for D ∈ 1:5, reps ∈ 1:5
x0 = zeros(ntuple(d->L, D))
g = gradient(x -> x[1], x0)[1] #Zero shift gradient
shift = ntuple(_ -> rand(-L:L), D) #Random shift
Expand Down Expand Up @@ -195,6 +226,7 @@ end
@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end

@info "Still GradChecking (next is var)"
@testset "var" begin
@test gradtest(var, rand(2, 3))
@test gradtest(x -> var(x, dims=1), rand(2, 3))
Expand Down Expand Up @@ -250,7 +282,7 @@ end
end

@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(maximum, rand(2, 4))

@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
Expand Down Expand Up @@ -300,6 +332,7 @@ end
@test gradtest(pinv, C)
end

@info "Still GradChecking (next is multiplication)"
@testset "multiplication" begin
@testset "matrix-matrix" begin
rng, M, P, Q = MersenneTwister(123456), 13, 7, 11
Expand All @@ -320,6 +353,7 @@ end
end
end

@info "Still GradChecking (next is backsolve)"
@testset "backsolve" begin
rng, M, P, Q = MersenneTwister(123456), 13, 10, 9
X, Y, y = randn(rng, P, P), randn(rng, P, Q), randn(rng, P)
Expand Down Expand Up @@ -469,6 +503,8 @@ end
end
end


@info "Still GradChecking (next is Hermitian)"
@testset "Hermitian" begin
rng, P = MersenneTwister(123456), 7
Re = randn(rng, P, P)
Expand Down Expand Up @@ -576,6 +612,7 @@ end
@test gradcheck(x->lyap(x[1],x[2]),[3.1,4.6])
end

@info "Still GradChecking (next is matrix exponential)"
@testset "matrix exponential" begin
@testset "real dense" begin
rng, N = MersenneTwister(6865931), 8
Expand Down Expand Up @@ -625,8 +662,8 @@ end
_hermsymtype(::Type{<:Symmetric}) = Symmetric
_hermsymtype(::Type{<:Hermitian}) = Hermitian

function _gradtest_hermsym(f, ST, A)
gradtest(_splitreim(collect(A))...) do (args...)
function _gradtest_hermsym(f, ST, A; kwargs...)
gradtest(_splitreim(collect(A))...; kwargs...) do (args...)
B = f(ST(_joinreim(_dropimaggrad.(args)...)))
return sum(_splitreim(B))
end
Expand Down Expand Up @@ -684,11 +721,6 @@ function _randvectorin(rng, n, r)
return rand(rng, n) .* (u - l) .+ l
end

realdomainrange(::Any) = (Inf, Inf)
realdomainrange(::Union{typeof.((acos,asin,atanh))...}) = (-1, 1)
realdomainrange(::typeof(acosh)) = (1, Inf)
realdomainrange(::Union{typeof.((log,sqrt,^))...}) = (0, Inf)

function _randmatseries(rng, f, T, n, domain::Type{Real})
U = _randmatunitary(rng, T, n)
λ = _randvectorin(rng, n, realdomainrange(f))
Expand All @@ -706,6 +738,7 @@ end

_randmatseries(rng, ::typeof(atanh), T, n, domain::Type{Complex}) = nothing

@info "Still GradChecking (next is power series)"
@testset "Hermitian/Symmetric power series functions" begin
MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64})
rng, N = MersenneTwister(123), 7
Expand Down Expand Up @@ -819,6 +852,7 @@ _randmatseries(rng, ::typeof(atanh), T, n, domain::Type{Complex}) = nothing
end
end

@info "Still GradChecking (next is ^ on Symetric)"
@testset "^(::Union{Symmetric,Hermitian}, p::Integer)" begin
MTs = (Symmetric{Float64}, Symmetric{ComplexF64},
Hermitian{Float64}, Hermitian{ComplexF64})
Expand Down Expand Up @@ -851,6 +885,7 @@ end
end
end

@info "Still GradChecking (next is Distances)"
using Distances

Zygote.refresh()
Expand Down Expand Up @@ -991,6 +1026,7 @@ end
@test gradcheck(x -> muladd(x[1], x[2], x[3]), [2.0, 3.0, 5.0])
end

@info "Still GradChecking (next is StatsFuns)"
import StatsFuns

Zygote.refresh()
Expand Down Expand Up @@ -1075,6 +1111,7 @@ end
@test gradcheck(x -> sum(sum(diag.([x] .* a))), b)
end

@info "Still GradChecking (next is Buffer)"
using Zygote: Buffer

@testset "Buffer" begin
Expand Down Expand Up @@ -1161,6 +1198,7 @@ end

end

@info "Still GradChecking (next is FillArrays)"
@testset "FillArrays" begin
rng, M, N = MersenneTwister(123456), 7, 11
x, y = randn(rng), randn(rng)
Expand Down