diff --git a/Project.toml b/Project.toml index 18f0776..cc71dee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" authors = ["willtebbutt and contributors"] -version = "0.6.3" +version = "0.6.4" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" diff --git a/src/util/chainrules.jl b/src/util/chainrules.jl index 0d2bec6..9b40b25 100644 --- a/src/util/chainrules.jl +++ b/src/util/chainrules.jl @@ -181,15 +181,20 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_ma return Fill(y_el, axes(x)), _map_Fill_rrule end -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, x::Fill, y::Fill) - z_el, back = ChainRulesCore.rrule_via_ad(config, f, x.value, y.value) +# Somehow needed to avoid the _map -> map indirection +function _map(f, xs::Fill...) + all(==(axes(first(xs))), axes.(xs)) || error("All axes should be the same") + Fill(f(FillArrays.getindex_value.(xs)...), axes(first(xs))) +end + +function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, xs::Fill...) + z_el, back = ChainRulesCore.rrule_via_ad(config, f, FillArrays.getindex_value.(xs)...) function _map_Fill_rrule(Δ) - Δf, Δx_el, Δy_el = back(unthunk(Δ).value) - return NoTangent(), Δf, Fill(Δx_el, axes(x)), Fill(Δy_el, axes(x)) + Δf, Δxs_el... = back(unthunk(Δ).value) + return NoTangent(), Δf, Fill.(Δxs_el, axes.(xs))... end - return Fill(z_el, axes(x)), _map_Fill_rrule + return Fill(z_el, axes(first(xs))), _map_Fill_rrule end - ### Same thing for `StructArray` diff --git a/test/util/chainrules.jl b/test/util/chainrules.jl index 2fd8267..0d71611 100644 --- a/test/util/chainrules.jl +++ b/test/util/chainrules.jl @@ -80,13 +80,17 @@ include("../test_util.jl") test_rrule(_map, x -> 2.0 * x, x; check_inferred=false) test_rrule(ZygoteRuleConfig(), (x,a)-> _map(x -> x * a, x), x, 2.0; check_inferred=false, rrule_f=rrule_via_ad) end - @testset "_map(f, x1::Fill, x2::Fill)" begin + @testset "_map(f, x::Fill....)" begin x1 = Fill(randn(3, 4), 3) x2 = Fill(randn(3, 4), 3) + x3 = Fill(randn(3, 4), 3) @test _map(+, x1, x2) == _map(+, collect(x1), collect(x2)) test_rrule(_map, +, x1, x2; check_inferred=true) + @test _map(+, x1, x2, x3) == _map(+, collect(x1), collect(x2), collect(x3)) + test_rrule(_map, +, x1, x2, x3; check_inferred=true) + fsin(x, y) = sin.(x .* y) test_rrule(_map, fsin, x1, x2; check_inferred=false)