Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ end
If `check_inferred=true`, then the type-stability of the `frule` is checked.
All remaining keyword arguments are passed to `isapprox`.
"""
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), check_inferred=true, kwargs...)
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, fkwargs::NamedTuple=NamedTuple(), check_inferred::Bool=true, kwargs...)
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

Expand All @@ -209,7 +209,9 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm
xs = first.(xẋs)
ẋs = last.(xẋs)
check_inferred && _test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
Ω_ad, dΩ_ad = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
Ω_ad, dΩ_ad = res
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
check_equal(Ω_ad, Ω; isapprox_kwargs...)

Expand Down Expand Up @@ -241,7 +243,7 @@ If `check_inferred=true`, then the type-stability of the `rrule` and the pullbac
returns are checked.
All remaining keyword arguments are passed to `isapprox`.
"""
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, check_inferred=true, fkwargs=NamedTuple(), kwargs...)
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, check_inferred::Bool=true, fkwargs::NamedTuple=NamedTuple(), kwargs...)
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

Expand All @@ -251,7 +253,9 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
xs = first.(xx̄s)
accumulated_x̄ = last.(xx̄s)
check_inferred && _test_inferred(rrule, f, xs...; fkwargs...)
y_ad, pullback = rrule(f, xs...; fkwargs...)
res = rrule(f, xs...; fkwargs...)
res === nothing && throw(MethodError(rrule, typeof((f, xs...))))
y_ad, pullback = res
y = f(xs...; fkwargs...)
check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct

Expand Down