Skip to content

Commit

Permalink
Merge f00e61e into 2c009d1
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Apr 28, 2019
2 parents 2c009d1 + f00e61e commit 2ead855
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,58 +3,64 @@ using FDM: jvp, j′vp
const _fdm = central_fdm(5, 1)

"""
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1))
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
# Arguments
- `f`: Function for which the `frule` should be tested.
- `x`: input at which to evaluate `f` (should generally be set randomly).
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
All keyword arguments except for `fdm` are passed to `isapprox`.
"""
function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm)
return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm)
function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm, kwargs...)
end

function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm)
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
xs, ẋs = collect(zip(xẋs...))
Ω, dΩ_rule = ChainRules.frule(f, xs...)
@test f(xs...) == Ω

dΩ_ad, dΩ_fd = dΩ_rule(ẋs...), jvp(fdm, xs->f(xs...), (xs, ẋs))
@test cr_isapprox(dΩ_ad, dΩ_fd, rtol, atol)
@test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...)
end

"""
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1))
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
# Arguments
- `f`: Function to which rule should be applied.
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
- `x`: input at which to evaluate `f` (should generally be set randomly).
- `x̄`: currently accumulated adjoint (should generally be set randomly).
All keyword arguments except for `fdm` are passed to `isapprox`.
"""
function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm)
function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
# Check correctness of evaluation.
fx, dx = ChainRules.rrule(f, x)
@test fx f(x)

# Correctness testing via finite differencing.
x̄_ad, x̄_fd = dx(ȳ), j′vp(fdm, f, ȳ, x)
@test cr_isapprox(x̄_ad, x̄_fd, rtol, atol)
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)

# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
test_accumulation(x̄, dx, ȳ, x̄_ad)
test_accumulation(Zero(), dx, ȳ, x̄_ad)
end

function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm)
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
# Check correctness of evaluation.
xs, x̄s = collect(zip(xx̄s...))
Ω, Δx_rules = ChainRules.rrule(f, xs...)
@test f(xs...) == Ω

# Correctness testing via finite differencing.
Δxs_ad, Δxs_fd = map(Δx_rule->Δx_rule(ȳ), Δx_rules), j′vp(fdm, f, ȳ, xs...)
@test all(map((Δx_ad, Δx_fd)->cr_isapprox(Δx_ad, Δx_fd, rtol, atol), Δxs_ad, Δxs_fd))
@test all(zip(Δxs_ad, Δxs_fd)) do (Δx_ad, Δx_fd)
isapprox(Δx_ad, Δx_fd; rtol=rtol, atol=atol, kwargs...)
end

# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
map(x̄s, Δx_rules, Δxs_ad) do x̄, Δx_rule, Δx_ad
Expand All @@ -64,20 +70,17 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
end
end

function cr_isapprox(d_ad, d_fd, rtol, atol)
return isapprox(d_ad, d_fd; rtol=rtol, atol=atol)
end
function cr_isapprox(ad::Wirtinger, fd, rtol, atol)
function Base.isapprox(ad::Wirtinger, fd; kwargs...)
error("Finite differencing with Wirtinger rules not implemented")
end
function cr_isapprox(d_ad::Casted, d_fd, rtol, atol)
return all(isapprox.(extern(d_ad), d_fd; rtol=rtol, atol=atol))
function Base.isapprox(d_ad::Casted, d_fd; kwargs...)
return all(isapprox.(extern(d_ad), d_fd; kwargs...))
end
function cr_isapprox(d_ad::DNE, d_fd, rtol, atol)
function Base.isapprox(d_ad::DNE, d_fd; kwargs...)
error("Tried to differentiate w.r.t. a DNE")
end
function cr_isapprox(d_ad::Thunk, d_fd, rtol, atol)
return isapprox(extern(d_ad), d_fd; rtol=rtol, atol=atol)
function Base.isapprox(d_ad::Thunk, d_fd; kwargs...)
return isapprox(extern(d_ad), d_fd; kwargs...)
end

function test_accumulation(x̄, dx, ȳ, partial)
Expand Down

0 comments on commit 2ead855

Please sign in to comment.