Skip to content

Commit

Permalink
add getindex rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jul 17, 2020
1 parent 026a118 commit 85dd9d9
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,25 @@ function rrule(::typeof(fill), value::Any, dims::Int...)
end
return fill(value, dims), fill_pullback
end

#####
##### getindex
#####

function rrule(::typeof(getindex), x::Array{<:Number}, inds::Union{Int, Vararg{Int}})
y = getindex(x, inds...)
function getindex_pullback(ȳ)
function getindex_add!(Δ)
Δ[inds...] .+= ȳ;
return Δ
end

= InplaceableThunk(
@thunk(getindex_add!(zeros(x))),
getindex_add!
)
return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...)
end

return y, getindex_pullback
end

0 comments on commit 85dd9d9

Please sign in to comment.