Skip to content

Commit

Permalink
correct and test getindex rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jul 17, 2020
1 parent 85dd9d9 commit 86be30c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@ end
##### getindex
#####

function rrule(::typeof(getindex), x::Array{<:Number}, inds::Union{Int, Vararg{Int}})
function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int})
y = getindex(x, inds...)
function getindex_pullback(ȳ)
function getindex_add!(Δ)
Δ[inds...] .+=;
Δ[inds...] = Δ[inds...] .+
return Δ
end

= InplaceableThunk(
@thunk(getindex_add!(zeros(x))),
@thunk(getindex_add!(zero(x))),
getindex_add!
)
return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...)
Expand Down
16 changes: 15 additions & 1 deletion test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end
(ds, dv, dd) = pullback(ones(4))
@test ds === NO_FIELDS
@test dd isa DoesNotExist
@test extern(dv) == 4
@test extern(dv) == 4

y, pullback = rrule(fill, 2.0, (3, 3, 3))
@test y == fill(2.0, (3, 3, 3))
Expand All @@ -87,3 +87,17 @@ end
@test dd isa DoesNotExist
@test dv 27.0
end

@testset "getindex" begin
x = [1.0 2.0 3.0; 10.0 20.0 30.0]
ind = [2,3]
= 7.2
x̄_fd, = j′vp(ChainRulesTestUtils._fdm, a->getindex(a, ind...), ȳ, x)
y, pullback = rrule(getindex, x, ind...)
_, x̄_ad, = pullback(ȳ)

@test unthunk(x̄_ad) x̄_fd

x_like = x .+ 1.0
@test x̄_ad.add!(copy(x_like)) x_like + x̄_fd
end

0 comments on commit 86be30c

Please sign in to comment.