From 86be30c350b8da9d520c8a026304e8aa92dc89a7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 17 Jul 2020 14:59:05 +0100 Subject: [PATCH] correct and test getindex rrule --- src/rulesets/Base/array.jl | 6 +++--- test/rulesets/Base/array.jl | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index e5cbcb388..bb3c054ca 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -105,16 +105,16 @@ end ##### getindex ##### -function rrule(::typeof(getindex), x::Array{<:Number}, inds::Union{Int, Vararg{Int}}) +function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int}) y = getindex(x, inds...) function getindex_pullback(ȳ) function getindex_add!(Δ) - Δ[inds...] .+= ȳ; + Δ[inds...] = Δ[inds...] .+ ȳ return Δ end x̄ = InplaceableThunk( - @thunk(getindex_add!(zeros(x))), + @thunk(getindex_add!(zero(x))), getindex_add! ) return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index b807bbef5..29dad39c1 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -78,7 +78,7 @@ end (ds, dv, dd) = pullback(ones(4)) @test ds === NO_FIELDS @test dd isa DoesNotExist - @test extern(dv) == 4 + @test extern(dv) == 4 y, pullback = rrule(fill, 2.0, (3, 3, 3)) @test y == fill(2.0, (3, 3, 3)) @@ -87,3 +87,17 @@ end @test dd isa DoesNotExist @test dv ≈ 27.0 end + +@testset "getindex" begin + x = [1.0 2.0 3.0; 10.0 20.0 30.0] + ind = [2,3] + ȳ = 7.2 + x̄_fd, = j′vp(ChainRulesTestUtils._fdm, a->getindex(a, ind...), ȳ, x) + y, pullback = rrule(getindex, x, ind...) + _, x̄_ad, = pullback(ȳ) + + @test unthunk(x̄_ad) ≈ x̄_fd + + x_like = x .+ 1.0 + @test x̄_ad.add!(copy(x_like)) ≈ x_like + x̄_fd +end \ No newline at end of file