Skip to content

Commit

Permalink
handle all indexing on Arrays
Browse files Browse the repository at this point in the history
more tests
  • Loading branch information
oxinabox committed Oct 16, 2020
1 parent 10a8f0d commit 631cdf7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 8 deletions.
12 changes: 9 additions & 3 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
##### getindex
#####

function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int})
y = getindex(x, inds...)
function rrule(::typeof(getindex), x::Array, inds...)
# removes any logical indexing, CartesianIndex etc
# leaving us just with a tuple of Int, Arrays of Int and Ranges of Int
plain_inds = Base.to_indices(x, inds)
y = getindex(x, plain_inds...)
function getindex_pullback(ȳ)
function getindex_add!(Δ)
Δ[inds...] = Δ[inds...] .+
# this a optimizes away for simple cases
for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...))
Δ[ii...] += ȳ_ii
end
return Δ
end

Expand Down
60 changes: 55 additions & 5 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,57 @@
@testset "getindex" begin
x = [1.0 2.0 3.0; 10.0 20.0 30.0]
= [1.4 2.5 3.7; 10.5 20.1 30.2]
rrule_test(getindex, 2.3, (x, x̄), (2, nothing))
rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing))
rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing))
@testset "getindex(::Matrix{<:Number},...)" begin
x = [1.0 2.0 3.0; 10.0 20.0 30.0]
= [1.4 2.5 3.7; 10.5 20.1 30.2]
full_ȳ = [7.4 5.5 2.7; 8.5 11.1 4.2]

@testset "single element" begin
rrule_test(getindex, 2.3, (x, x̄), (2, nothing))
rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (1, nothing))
rrule_test(getindex, 2.3, (x, x̄), (2, nothing), (2, nothing))

rrule_test(getindex, 2.3, (x, x̄), (CartesianIndex(2, 3), nothing))
end

@testset "slice/index postions" begin
rrule_test(getindex, [2.3, 3.1], (x, x̄), (2:3, nothing))
rrule_test(getindex, [2.3, 3.1], (x, x̄), (3:-1:2, nothing))
rrule_test(getindex, [2.3, 3.1], (x, x̄), ([3,2], nothing))
rrule_test(getindex, [2.3, 3.1], (x, x̄), ([2,3], nothing))

rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (1:2, nothing), (2:3, nothing))
rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (:, nothing), (2:3, nothing))

rrule_test(getindex, [2.3, 3.1], (x, x̄), (2:3, nothing), (1, nothing))
rrule_test(getindex, [2.3, 3.1], (x, x̄), (1, nothing), (2:3, nothing))

rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (1:2, nothing), (2:3, nothing))
rrule_test(getindex, [2.3 3.1; 4.1 5.1], (x, x̄), (:, nothing), (2:3, nothing))


rrule_test(getindex, full_ȳ, (x, x̄), (:, nothing), (:, nothing))
rrule_test(getindex, full_ȳ[:], (x, x̄), (:, nothing))
end

@testset "masking" begin
rrule_test(getindex, full_ȳ, (x, x̄), (trues(size(x)), nothing))
rrule_test(getindex, full_ȳ[:], (x, x̄), (trues(length(x)), nothing))

mask = falses(size(x))
mask[2,3] = true
mask[1,2] = true
rrule_test(getindex, [2.3, 3.1], (x, x̄), (mask, nothing))

rrule_test(
getindex, full_ȳ[1,:], (x, x̄), ([true, false], nothing), (:, nothing)
)
end

@testset "By position with repeated elements" begin
rrule_test(getindex, [2.3, 3.1], (x, x̄), ([2, 2], nothing))
rrule_test(getindex, [2.3, 3.1, 4.1], (x, x̄), ([2, 2, 2], nothing))
rrule_test(
getindex, [2.3 3.1; 4.1 5.1], (x, x̄), ([2,2], nothing), ([3,3], nothing)
)
end
end
end

0 comments on commit 631cdf7

Please sign in to comment.