Skip to content

Commit

Permalink
Merge pull request #163 from JuliaArrays/cjf/generated-function-fixes
Browse files Browse the repository at this point in the history
Respect generated function invariants
  • Loading branch information
timholy committed Aug 11, 2019
2 parents a5556f1 + 8fe78ec commit 2f81a7a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
7 changes: 2 additions & 5 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,10 @@ Given an AxisArray and an Axis, return the integer dimension of
the Axis within the array.
"""
axisdim(A::AxisArray, ax::Axis) = axisdim(A, typeof(ax))
@generated function axisdim(A::AxisArray, ax::Type{Ax}) where Ax<:Axis
dim = axisdim(A, Ax)
:($dim)
end
axisdim(A::AxisArray, ax::Type{Ax}) where Ax<:Axis = axisdim(typeof(A), Ax)
# The actual computation is done in the type domain, which is a little tricky
# due to type invariance.
function axisdim(::Type{AxisArray{T,N,D,Ax}}, ::Type{<:Axis{name,S} where S}) where {T,N,D,Ax,name}
@generated function axisdim(::Type{AxisArray{T,N,D,Ax}}, ::Type{<:Axis{name,S} where S}) where {T,N,D,Ax,name}
isa(name, Int) && return name <= N ? name : error("axis $name greater than array dimensionality $N")
names = axisnames(Ax)
idx = findfirst(isequal(name), names)
Expand Down
17 changes: 15 additions & 2 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,28 @@ function axisindexes(::Type{Categorical}, ax::AbstractVector, idx::AbstractVecto
res
end

# Creates *instances* of axis traits for a set of axes.
# TODO: Transition axistrait() to return trait instances in line with common
# practice in Base and other packages.
#
# This function is a utility tool to ensure that `axistrait` is only called
# from outside the generated function below. (If not, we can get world age
# errors.)
_axistraits(ax1, rest...) = (axistrait(ax1)(), _axistraits(rest...)...)
_axistraits() = ()

# This catch-all method attempts to convert any axis-specific non-standard
# indexing types to their integer or integer range equivalents using axisindexes
# It is separate from the `Base.getindex` function to allow reuse between
# set- and get- index.
@generated function to_index(A::AxisArray{T,N,D,Ax}, I...) where {T,N,D,Ax}
to_index(A::AxisArray, I...) = _to_index(A, _axistraits(I...), I...)

@generated function _to_index(A::AxisArray{T,N,D,Ax}, axtraits, I...) where {T,N,D,Ax}
ex = Expr(:tuple)
n = 0
axtrait_types = axtraits.parameters
for i=1:length(I)
if axistrait(I[i]) <: Categorical && i <= length(Ax.parameters)
if axtrait_types[i] <: Categorical && i <= length(Ax.parameters)
if I[i] <: Axis
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
else
Expand Down
10 changes: 5 additions & 5 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ B = AxisArray(reshape(1:15, 5,3), .1:.1:0.5, [:a, :b, :c])
@test @view(B[ClosedInterval(0.2, 0.6), :]) == @view(B[ClosedInterval(0.2, 0.6)]) == B[2:end,:]

# Test Categorical indexing
@test B[:, :a] == @view(B[:, :a]) == B[:,1]
@test B[:, :c] == @view(B[:, :c]) == B[:,3]
@test B[:, [:a]] == @view(B[:, [:a]]) == B[:,[1]]
@test B[:, [:c]] == @view(B[:, [:c]]) == B[:,[3]]
@test B[:, [:a,:c]] == @view(B[:, [:a,:c]]) == B[:,[1,3]]
@test @inferred(B[:, :a]) == @view(B[:, :a]) == B[:,1]
@test @inferred(B[:, :c]) == @view(B[:, :c]) == B[:,3]
@test @inferred(B[:, [:a]]) == @view(B[:, [:a]]) == B[:,[1]]
@test @inferred(B[:, [:c]]) == @view(B[:, [:c]]) == B[:,[3]]
@test @inferred(B[:, [:a,:c]]) == @view(B[:, [:a,:c]]) == B[:,[1,3]]

@test B[Axis{:row}(ClosedInterval(0.15, 0.3))] == @view(B[Axis{:row}(ClosedInterval(0.15, 0.3))]) == B[2:3,:]

Expand Down

0 comments on commit 2f81a7a

Please sign in to comment.