diff --git a/Project.toml b/Project.toml index 24328640..7d328cb0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.5.9" +version = "0.5.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/testers.jl b/src/testers.jl index 10020f26..dde9dfc9 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -70,19 +70,20 @@ function _make_j′vp_call(fdm, f, ȳ, xs, ignores) @assert length(fd) == length(arginds) for (dx, ind) in zip(fd, arginds) - args[ind] = _maybe_fix_to_composite(dx) + args[ind] = _maybe_fix_to_composite(xs[ind], dx) end return (args...,) end """ - _make_jvp_call(fdm, f, xs, ẋs, ignores) + _make_jvp_call(fdm, f, y, xs, ẋs, ignores) Call `FiniteDifferences.jvp`, with the option to ignore certain `xs`. # Arguments - `fdm::FiniteDifferenceMethod`: How to numerically differentiate `f`. - `f`: The function to differentiate. +- `y`: The primal output `y=f(xs...)` or at least something of the right type - `xs`: Inputs to `f`, such that `y = f(xs...)`. - `ẋs`: The directional derivatives of `xs` w.r.t. some real number `t`. - `ignores`: Collection of `Bool`s, the same length as `xs` and `ẋs`. @@ -91,21 +92,21 @@ Call `FiniteDifferences.jvp`, with the option to ignore certain `xs`. # Returns - `Ω̇`: Derivative of output w.r.t. `t` estimated by finite differencing. """ -function _make_jvp_call(fdm, f, xs, ẋs, ignores) +function _make_jvp_call(fdm, f, y, xs, ẋs, ignores) f2 = _wrap_function(f, xs, ignores) ignores = collect(ignores) all(ignores) && return ntuple(_->nothing, length(xs)) sigargs = zip(xs[.!ignores], ẋs[.!ignores]) - return _maybe_fix_to_composite(jvp(fdm, f2, sigargs...)) + return _maybe_fix_to_composite(y, jvp(fdm, f2, sigargs...)) end # TODO: remove after https://github.com/JuliaDiff/FiniteDifferences.jl/issues/97 # For functions which return a tuple, FD returns a tuple to represent the differential. Tuple # is not a natural differential, because it doesn't overload +, so make it a Composite. -_maybe_fix_to_composite(x::Tuple) = Composite{typeof(x)}(x...) -_maybe_fix_to_composite(x::NamedTuple) = Composite{typeof(x)}(;x...) -_maybe_fix_to_composite(x) = x +_maybe_fix_to_composite(::P, x::Tuple) where {P} = Composite{P}(x...) +_maybe_fix_to_composite(::P, x::NamedTuple) where {P} = Composite{P}(;x...) +_maybe_fix_to_composite(::Any, x) = x """ test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...) @@ -197,7 +198,7 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm ẋs_is_ignored = ẋs .== nothing # Correctness testing via finite differencing. - dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), xs, ẋs, ẋs_is_ignored) + dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored) check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...) diff --git a/test/testers.jl b/test/testers.jl index 640b7ba3..b280af05 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -386,4 +386,25 @@ end @test fails(()->rrule_test(my_identity2, 4.1, (2.2, 3.3))) end end + + + @testset "Tuple primal that is not equal to differential backing" begin + # https://github.com/JuliaMath/SpecialFunctions.jl/issues/288 + forwards_trouble(x) = (1, 2.0*x) + @scalar_rule(forwards_trouble(v), Zero(), 2.0) + frule_test(forwards_trouble, (2.5, 2.1)) + + rev_trouble((x,y)) = y + function ChainRulesCore.rrule(::typeof(rev_trouble), (x,y)::P) where P + rev_trouble_pullback(ȳ) = (NO_FIELDS, Composite{P}(Zero(), ȳ)) + return y, rev_trouble_pullback + end + rrule_test( + rev_trouble, 2.5, + ( + (3, 3.0), + Composite{Tuple{Int, Float64}}(Zero(), 1.0) + ) + ) + end end