diff --git a/Project.toml b/Project.toml index 733172a0..b7e70ed5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.8" +version = "0.6.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/testers.jl b/src/testers.jl index 5505d45f..ad02defc 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -114,8 +114,7 @@ function test_frule( Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...) check_equal(Ω_ad, Ω; isapprox_kwargs...) - # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 - ẋs_is_ignored = isa.(ẋs, Union{Nothing, DoesNotExist}) + ẋs_is_ignored = _ignore.(ẋs) if any(ẋs .== nothing) Base.depwarn( "test_frule(f, k ⊢ nothing) is deprecated, use " * @@ -191,8 +190,7 @@ function test_rrule( @test ∂self === NO_FIELDS # No internal fields # Correctness testing via finite differencing. - # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 - x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing, DoesNotExist}) + x̄s_is_ignored = _ignore.(accumulated_x̄) if any(accumulated_x̄ .== nothing) Base.depwarn( "test_rrule(f, k ⊢ nothing) is deprecated, use " * @@ -201,9 +199,9 @@ function test_rrule( ) end - x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne) + x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_ignored) for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd) - if accumulated_x̄ isa Union{Nothing, DoesNotExist} # then we marked this argument as not differentiable # TODO remove once #113 + if _ignore(accumulated_x̄) # then we marked this argument as not differentiable @assert x̄_fd === nothing # this is how `_make_j′vp_call` works x̄_ad isa Zero && error( "The pullback in the rrule for $f function should use DoesNotExist()" * @@ -245,6 +243,16 @@ function _ensure_not_running_on_functor(f, name) end end +""" + _ignore(x) -> Bool + +Returns true for tangents we want to ignore in finite differencing, and false otherwise. +""" +_ignore(::Any) = false +_ignore(::DoesNotExist) = true +_ignore(::Nothing) = true # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 +_ignore(c::Composite) = all([d === DoesNotExist() for d in c]) + """ _test_inferred(f, args...; kwargs...) diff --git a/test/testers.jl b/test/testers.jl index 4eb6693f..6f21ec0f 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -526,4 +526,23 @@ end test_frule(f_notimplemented, randn(), randn()) test_rrule(f_notimplemented, randn(), randn()) end + + @testset "ignore Composites fully made of DoesNotExist" begin + mygetindex(x, ind) = x[ind] + function ChainRulesCore.rrule(::typeof(mygetindex), x, ind) + function myfunc_pullback(Δy) + xgrad = zero(x) + xgrad[ind] = Δy + return NO_FIELDS, xgrad, DoesNotExist() + end + return mygetindex(x, ind), myfunc_pullback + end + + # test that Composites made only of DoesNotExist() are ignored in finite differencing + test_rrule(mygetindex, [1, 2, 3, 4, 5.0], 1:2 ⊢ Composite{UnitRange{Int64}}(start=DoesNotExist(), stop=DoesNotExist())) + + # rand_tangent also returns Composites with all DoesNotExist() + test_rrule(mygetindex, [1, 2, 3, 4, 5.0], 1:2) + test_rrule(mygetindex, [1, 2, 3, 4, 5.0], 1:2:4) + end end