From 2ebd1f6be5a21440a490cb1a762f39c6ecff56b5 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 29 Apr 2021 12:49:00 +0100 Subject: [PATCH 1/6] implementation --- src/testers.jl | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/testers.jl b/src/testers.jl index 5505d45f..8b4e3d5c 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -114,8 +114,7 @@ function test_frule( Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...) check_equal(Ω_ad, Ω; isapprox_kwargs...) - # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 - ẋs_is_ignored = isa.(ẋs, Union{Nothing, DoesNotExist}) + ẋs_is_ignored = _ignore.(ẋs) if any(ẋs .== nothing) Base.depwarn( "test_frule(f, k ⊢ nothing) is deprecated, use " * @@ -165,6 +164,7 @@ function test_rrule( # To simplify some of the calls we make later lets group the kwargs for reuse isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) + println("HI") @testset "test_rrule: $f at $inputs" begin _ensure_not_running_on_functor(f, "test_rrule") @@ -191,8 +191,7 @@ function test_rrule( @test ∂self === NO_FIELDS # No internal fields # Correctness testing via finite differencing. - # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 - x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing, DoesNotExist}) + x̄s_is_ignored = _ignore.(accumulated_x̄) if any(accumulated_x̄ .== nothing) Base.depwarn( "test_rrule(f, k ⊢ nothing) is deprecated, use " * @@ -201,9 +200,9 @@ function test_rrule( ) end - x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne) + x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_ignored) for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd) - if accumulated_x̄ isa Union{Nothing, DoesNotExist} # then we marked this argument as not differentiable # TODO remove once #113 + if _ignore(accumulated_x̄) # then we marked this argument as not differentiable @assert x̄_fd === nothing # this is how `_make_j′vp_call` works x̄_ad isa Zero && error( "The pullback in the rrule for $f function should use DoesNotExist()" * @@ -245,6 +244,18 @@ function _ensure_not_running_on_functor(f, name) end end +""" + _ignore(x) -> Bool + +Decides whether to ignore certain kinds of arguments for finite differencing. +""" +_ignore(::Any) = false +_ignore(::DoesNotExist) = true +_ignore(::Nothing) = true # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 +function _ignore(c::Composite) + return all([d === DoesNotExist() for d in c]) +end + """ _test_inferred(f, args...; kwargs...) From d781445cf5a858c5e2e4151066566a68762e4659 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 29 Apr 2021 12:49:09 +0100 Subject: [PATCH 2/6] test --- test/testers.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/testers.jl b/test/testers.jl index 4eb6693f..6f21ec0f 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -526,4 +526,23 @@ end test_frule(f_notimplemented, randn(), randn()) test_rrule(f_notimplemented, randn(), randn()) end + + @testset "ignore Composites fully made of DoesNotExist" begin + mygetindex(x, ind) = x[ind] + function ChainRulesCore.rrule(::typeof(mygetindex), x, ind) + function myfunc_pullback(Δy) + xgrad = zero(x) + xgrad[ind] = Δy + return NO_FIELDS, xgrad, DoesNotExist() + end + return mygetindex(x, ind), myfunc_pullback + end + + # test that Composites made only of DoesNotExist() are ignored in finite differencing + test_rrule(mygetindex, [1, 2, 3, 4, 5.0], 1:2 ⊢ Composite{UnitRange{Int64}}(start=DoesNotExist(), stop=DoesNotExist())) + + # rand_tangent also returns Composites with all DoesNotExist() + test_rrule(mygetindex, [1, 2, 3, 4, 5.0], 1:2) + test_rrule(mygetindex, [1, 2, 3, 4, 5.0], 1:2:4) + end end From 62dedecac95a7869733e3370e6d9694fee6b6727 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 29 Apr 2021 12:49:28 +0100 Subject: [PATCH 3/6] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 733172a0..b7e70ed5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.8" +version = "0.6.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From ea9c154ca03e991dae19c4fe5bdb1da205cb3028 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 29 Apr 2021 12:50:14 +0100 Subject: [PATCH 4/6] typo --- src/testers.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/testers.jl b/src/testers.jl index 8b4e3d5c..b158848a 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -164,7 +164,6 @@ function test_rrule( # To simplify some of the calls we make later lets group the kwargs for reuse isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) - println("HI") @testset "test_rrule: $f at $inputs" begin _ensure_not_running_on_functor(f, "test_rrule") From 9de2f6e196f5a7ca452eec16b2ada5bafb2ec02a Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 29 Apr 2021 12:53:27 +0100 Subject: [PATCH 5/6] rephrase docstring --- src/testers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/testers.jl b/src/testers.jl index b158848a..3069e87d 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -246,7 +246,7 @@ end """ _ignore(x) -> Bool -Decides whether to ignore certain kinds of arguments for finite differencing. +Returns true for tangents we want to ignore in finite differencing, and false otherwise. """ _ignore(::Any) = false _ignore(::DoesNotExist) = true From 671de31c767abe93367c4698142b75323688a1bb Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 29 Apr 2021 13:02:54 +0100 Subject: [PATCH 6/6] composite to one line --- src/testers.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/testers.jl b/src/testers.jl index 3069e87d..ad02defc 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -251,9 +251,7 @@ Returns true for tangents we want to ignore in finite differencing, and false ot _ignore(::Any) = false _ignore(::DoesNotExist) = true _ignore(::Nothing) = true # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 -function _ignore(c::Composite) - return all([d === DoesNotExist() for d in c]) -end +_ignore(c::Composite) = all([d === DoesNotExist() for d in c]) """ _test_inferred(f, args...; kwargs...)