Skip to content

Commit

Permalink
Split getindex and view apart. Closes #38.
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Sep 19, 2016
1 parent 6ace815 commit 8b55b92
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 103 deletions.
153 changes: 57 additions & 96 deletions src/indexing.jl
@@ -1,65 +1,23 @@
### Indexing returns either a scalar or a smartly-subindexed AxisArray ###
typealias Idx Union{Real,Colon,AbstractArray{Int}}

# Limit indexing to types supported by SubArrays, at least initially
typealias Idx Union{Colon,Int,AbstractVector{Int}}
using Base: ViewIndex, linearindexing, unsafe_getindex, unsafe_setindex!

# Defer linearindexing to the wrapped array
import Base: linearindexing, unsafe_getindex, unsafe_setindex!
Base.linearindexing{T,N,D}(::AxisArray{T,N,D}) = linearindexing(D)

# Simple scalar indexing where we just set or return scalars
Base.getindex(A::AxisArray, idxs::Int...) = A.data[idxs...]
Base.setindex!(A::AxisArray, v, idxs::Int...) = (A.data[idxs...] = v)

# Default to views already
Base.getindex{T}(A::AxisArray{T,1}, idx::Colon) = A
@inline Base.getindex(A::AxisArray, idxs::Int...) = A.data[idxs...]
@inline Base.setindex!(A::AxisArray, v, idxs::Int...) = (A.data[idxs...] = v)

# Cartesian iteration
Base.eachindex(A::AxisArray) = eachindex(A.data)
Base.getindex(A::AxisArray, idx::Base.IteratorsMD.CartesianIndex) = A.data[idx]
Base.setindex!(A::AxisArray, v, idx::Base.IteratorsMD.CartesianIndex) = (A.data[idx] = v)

# More complicated cases where we must create a subindexed AxisArray
# TODO: do we want to be dogmatic about using views? For the data? For the axes?
# TODO: perhaps it would be better to return an entirely lazy SubAxisArray view
@generated function Base.getindex{T,N,D,Ax}(A::AxisArray{T,N,D,Ax}, idxs::Idx...)
newdims = length(idxs)
# If the last index is a linear indexing range that may span multiple
# dimensions in the original AxisArray, we can no longer track those axes.
droplastaxis = N > newdims && !(idxs[end] <: Real) ? 1 : 0
# Drop trailing scalar dimensions
while newdims > 0 && idxs[newdims] <: Real
newdims -= 1
end
names = axisnames(A)
axes = Expr(:tuple)
Isplat = Expr[]
reshape = false
newshape = Expr[]
for i = 1:newdims-droplastaxis
prepaxis!(axes.args, Isplat, idxs[i], names, i)
end
for i = newdims-droplastaxis+1:length(idxs)
push!(Isplat, :(idxs[$i]))
end
quote
data = view(A.data, $(Isplat...))
AxisArray(data, $axes) # TODO: avoid checking the axes here
end
end

# When we index with non-vector arrays, we *add* dimensions. This isn't
# supported by SubArray currently, so we instead return a copy.
# TODO: we probably shouldn't hack Base like this, but it's so convenient...
if VERSION < v"0.5.0-dev"
@inline Base.index_shape_dim(A, dim, i::AbstractArray{Bool}, I...) = (sum(i), Base.index_shape_dim(A, dim+1, I...)...)
@inline Base.index_shape_dim(A, dim, i::AbstractArray, I...) = (size(i)..., Base.index_shape_dim(A, dim+1, I...)...)
end
@generated function Base.getindex(A::AxisArray, I::Union{Idx, AbstractArray{Int}}...)
@generated function reaxis(A::AxisArray, I::Idx...)
N = length(I)
Isplat = [:(I[$d]) for d=1:N]
# Determine the new axes:
# Like above, drop linear indexing over multiple axes
# Drop linear indexing over multiple axes
droplastaxis = ndims(A) > N && !(I[end] <: Real) ? 1 : 0
# Drop trailing scalar dimensions
lastnonscalar = N
Expand All @@ -70,44 +28,74 @@ end
newaxes = Expr[]
for d=1:lastnonscalar-droplastaxis
if I[d] <: AxisArray
# Indexing with an AxisArray joins the axis names
idxnames = axisnames(I[d])
for i=1:ndims(I[d])
push!(newaxes, :($(Axis{Symbol(names[d], "_", idxnames[i])})(I[$d].axes[$i].val)))
end
elseif I[d] <: Idx
push!(newaxes, :($(Axis{names[d]})(A.axes[$d].val[J[$d]])))
elseif I[d] <: Real
elseif I[d] <: Union{AbstractVector,Colon}
push!(newaxes, :($(Axis{names[d]})(A.axes[$d].val[Base.to_index(I[$d])])))
elseif I[d] <: AbstractArray
for i=1:ndims(I[d])
# When we index with non-vector arrays, we *add* dimensions.
push!(newaxes, :($(Axis{Symbol(names[d], "_", i)})(indices(I[$d], $i))))
end
end
end
quote
# First copy the data using scalar indexing - an adaptation of Base
checkbounds(A, I...)
J = Base.to_indexes($(Isplat...))
sz = Base.index_shape(A, J...)
idx_lens = Base.index_lengths(A, J...)
src = A.data
dest = similar(A.data, sz)
D = eachindex(dest)
Ds = start(D)
Base.Cartesian.@nloops $N i d->(1:idx_lens[d]) d->(@inbounds j_d = J[d][i_d]) begin
d, Ds = next(D, Ds)
v = Base.Cartesian.@ncall $N unsafe_getindex src j
unsafe_setindex!(dest, v, d)
end
# And now create the AxisArray:
AxisArray(dest, $(newaxes...))
($(newaxes...),)
end
end

@inline function Base.getindex(A::AxisArray, idxs::Idx...)
AxisArray(A.data[idxs...], reaxis(A, idxs...))
end

# To resolve ambiguities, we need several definitions
if VERSION >= v"0.6.0-dev.672"
using Base.AbstractCartesianIndex
Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
else
@inline function Base.view{T,N}(A::AxisArray{T,N}, idxs::Vararg{Idx,N})
AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
end
function Base.view(A::AxisArray, idx::Idx)
AxisArray(view(A.data, idx), reaxis(A, idx))
end
@inline function Base.view{N}(A::AxisArray, idxs::Vararg{Idx,N})
# this should eventually be deleted, see julia #14770
AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
end
end

# Setindex is so much simpler. Just assign it to the data:
Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)
@inline Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)

### Fancier indexing capabilities provided only by AxisArrays ###
Base.getindex(A::AxisArray, idxs...) = A[to_index(A,idxs...)...]
Base.setindex!(A::AxisArray, v, idxs...) = (A[to_index(A,idxs...)...] = v)
@inline Base.getindex(A::AxisArray, idxs...) = A[to_index(A,idxs...)...]
@inline Base.setindex!(A::AxisArray, v, idxs...) = (A[to_index(A,idxs...)...] = v)
# Deal with lots of ambiguities here
if VERSION >= v"0.6.0-dev.672"
Base.view(A::AxisArray, idxs::ViewIndex...) = view(A, to_index(A,idxs...)...)
Base.view(A::AxisArray, idxs::Union{ViewIndex,AbstractCartesianIndex}...) = view(A, to_index(A,Base.IteratorsMD.flatten(idxs)...)...)
Base.view(A::AxisArray, idxs...) = view(A, to_index(A,idxs...)...)
else
for T in (:ViewIndex, :Any)
@eval begin
@inline function Base.view{T,N}(A::AxisArray{T,N}, idxs::Vararg{$T,N})
view(A, to_index(A,idxs...)...)
end
function Base.view(A::AxisArray, idx::$T)
view(A, to_index(A,idx)...)
end
@inline function Base.view{N}(A::AxisArray, idsx::Vararg{$T,N})
# this should eventually be deleted, see julia #14770
view(A, to_index(A,idxs...)...)
end
end
end
end

# First is indexing by named axis. We simply sort the axes and re-dispatch.
# When indexing by named axis the shapes of omitted dimensions are preserved
Expand Down Expand Up @@ -214,30 +202,3 @@ end
meta = Expr(:meta, :inline)
return :($meta; $ex)
end

function prepaxis!{I<:Union{AbstractVector,Colon}}(axesargs, Isplat, ::Type{I}, names, i)
idx = :(idxs[$i])
push!(axesargs, :($(Axis{names[i]})(A.axes[$i].val[$idx])))
push!(Isplat, :(idxs[$i]))
axesargs, Isplat
end
function prepaxis!{I<:AxisArray}(axesargs, Isplat, ::Type{I}, names, i)
idxnames = axisnames(I)
push!(axesargs, :($(Axis{Symbol(names[i], "_", idxnames[1])})(idxs[$i].axes[1].val)))
push!(Isplat, :(idxs[$i]))
axesargs, Isplat
end
# For anything scalar-like
if VERSION < v"0.5.0-dev"
function prepaxis!{I}(axesargs, Isplat, ::Type{I}, names, i)
idx = :(idxs[$i]:idxs[$i])
push!(axesargs, :($(Axis{names[i]})(A.axes[$i].val[$idx])))
push!(Isplat, idx)
axesargs, Isplat
end
else
function prepaxis!{I}(axesargs, Isplat, ::Type{I}, names, i)
push!(Isplat, :(idxs[$i]))
axesargs, Isplat
end
end
22 changes: 15 additions & 7 deletions test/indexing.jl
Expand Up @@ -9,15 +9,18 @@ D[1,1,1,1,1] = 10
@test A[:,:,:] == A[Axis{:row}(:)] == A[Axis{:col}(:)] == A[Axis{:page}(:)] == A.data[:,:,:]
# Test UnitRange slices
@test A[1:2,:,:] == A.data[1:2,:,:] == A[Axis{:row}(1:2)] == A[Axis{1}(1:2)] == A[Axis{:row}(ClosedInterval(-Inf,Inf))] == A[[true,true],:,:]
@test @view(A[1:2,:,:]) == A.data[1:2,:,:] == @view(A[Axis{:row}(1:2)]) == @view(A[Axis{1}(1:2)]) == @view(A[Axis{:row}(ClosedInterval(-Inf,Inf))]) == @view(A[[true,true],:,:])
@test A[:,1:2,:] == A.data[:,1:2,:] == A[Axis{:col}(1:2)] == A[Axis{2}(1:2)] == A[Axis{:col}(ClosedInterval(0.0, .25))] == A[:,[true,true,false],:]
@test @view(A[:,1:2,:]) == A.data[:,1:2,:] == @view(A[Axis{:col}(1:2)]) == @view(A[Axis{2}(1:2)]) == @view(A[Axis{:col}(ClosedInterval(0.0, .25))]) == @view(A[:,[true,true,false],:])
@test A[:,:,1:2] == A.data[:,:,1:2] == A[Axis{:page}(1:2)] == A[Axis{3}(1:2)] == A[Axis{:page}(ClosedInterval(-1., .22))] == A[:,:,[true,true,false,false]]
@test @view(A[:,:,1:2]) == @view(A.data[:,:,1:2]) == @view(A[Axis{:page}(1:2)]) == @view(A[Axis{3}(1:2)]) == @view(A[Axis{:page}(ClosedInterval(-1., .22))]) == @view(A[:,:,[true,true,false,false]])
# Test scalar slices
@test A[2,:,:] == A.data[2,:,:] == A[Axis{:row}(2)]
@test A[:,2,:] == A.data[:,2,:] == A[Axis{:col}(2)]
@test A[:,:,2] == A.data[:,:,2] == A[Axis{:page}(2)]

# Test fallback methods
@test A[[1 2; 3 4]] == A.data[[1 2; 3 4]]
@test A[[1 2; 3 4]] == @view(A[[1 2; 3 4]]) == A.data[[1 2; 3 4]]
@test A[] == A.data[]

# Test axis restrictions
Expand Down Expand Up @@ -45,14 +48,19 @@ B = AxisArray(reshape(1:15, 5,3), .1:.1:0.5, [:a, :b, :c])
@test B[ClosedInterval(0.15, 0.3), :] == B[ClosedInterval(0.15, 0.3)] == B[2:3,:]
@test B[ClosedInterval(0.2, 0.5), :] == B[ClosedInterval(0.2, 0.5)] == B[2:end,:]
@test B[ClosedInterval(0.2, 0.6), :] == B[ClosedInterval(0.2, 0.6)] == B[2:end,:]
@test @view(B[ClosedInterval(0.0, 0.5), :]) == @view(B[ClosedInterval(0.0, 0.5)]) == B[:,:]
@test @view(B[ClosedInterval(0.0, 0.3), :]) == @view(B[ClosedInterval(0.0, 0.3)]) == B[1:3,:]
@test @view(B[ClosedInterval(0.15, 0.3), :]) == @view(B[ClosedInterval(0.15, 0.3)]) == B[2:3,:]
@test @view(B[ClosedInterval(0.2, 0.5), :]) == @view(B[ClosedInterval(0.2, 0.5)]) == B[2:end,:]
@test @view(B[ClosedInterval(0.2, 0.6), :]) == @view(B[ClosedInterval(0.2, 0.6)]) == B[2:end,:]

# Test Categorical indexing
@test B[:, :a] == B[:,1]
@test B[:, :c] == B[:,3]
@test B[:, [:a]] == B[:,[1]]
@test B[:, [:a,:c]] == B[:,[1,3]]
@test B[:, :a] == @view(B[:, :a]) == B[:,1]
@test B[:, :c] == @view(B[:, :c]) == B[:,3]
@test B[:, [:a]] == @view(B[:, [:a]]) == B[:,[1]]
@test B[:, [:a,:c]] == @view(B[:, [:a,:c]]) == B[:,[1,3]]

@test B[Axis{:row}(ClosedInterval(0.15, 0.3))] == B[2:3,:]
@test B[Axis{:row}(ClosedInterval(0.15, 0.3))] == @view(B[Axis{:row}(ClosedInterval(0.15, 0.3))]) == B[2:3,:]

A = AxisArray(reshape(1:256, 4,4,4,4), Axis{:d1}(.1:.1:.4), Axis{:d2}(1//10:1//10:4//10), Axis{:d3}(["1","2","3","4"]), Axis{:d4}([:a, :b, :c, :d]))
ax1 = axes(A)[1]
Expand All @@ -68,7 +76,7 @@ A = AxisArray(reshape(1:32, 2, 2, 2, 2, 2), .1:.1:.2, .1:.1:.2, .1:.1:.2, [:a, :

# Test vectors
v = AxisArray(collect(.1:.1:10.0), .1:.1:10.0)
@test v[Colon()] === v
@test v[Colon()] == v
@test v[:] == v.data[:] == v[Axis{:row}(:)]
@test v[3:8] == v.data[3:8] == v[ClosedInterval(.25,.85)] == v[Axis{:row}(3:8)] == v[Axis{:row}(ClosedInterval(.22,.88))]

Expand Down

0 comments on commit 8b55b92

Please sign in to comment.