Skip to content

Commit

Permalink
Add CategoricalVector type
Browse files Browse the repository at this point in the history
  • Loading branch information
iamed2 committed Jun 5, 2017
1 parent 69390fe commit 2565c2c
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/AxisArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ include("intervals.jl")
include("search.jl")
include("indexing.jl")
include("sortedvector.jl")
include("categoricalvector.jl")
include("combine.jl")

end
81 changes: 81 additions & 0 deletions src/categoricalvector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@

export CategoricalVector

"""
A CategoricalVector is an AbstractVector which is treated as a categorical axis regardless
of the element type. Duplicate values are not allowed but are not filtered out.
A CategoricalVector axis can be indexed with an ClosedInterval, with a value, or with a
vector of values. Use of a CategoricalVector{Tuple} axis allows indexing similar to the
hierarchical index of the Python Pandas package or the R data.table package.
In general, indexing into a CategoricalVector will be much slower than the corresponding
SortedVector or another sorted axis type, as linear search is required.
### Constructors
```julia
CategoricalVector(x::AbstractVector)
```
### Arguments
* `x::AbstractVector` : the wrapped vector
### Examples
```julia
v = CategoricalVector(collect([1; 8; 10:15]))
A = AxisArray(reshape(1:16, 8, 2), v, [:a, :b])
A[Axis{:row}(1), :]
A[Axis{:row}(10), :]
A[Axis{:row}([1, 10]), :]
## Hierarchical index example with three key levels
data = reshape(1.:40., 20, 2)
v = collect(zip([:a, :b, :c][rand(1:3,20)], [:x,:y][rand(1:2,20)], [:x,:y][rand(1:2,20)]))
A = AxisArray(data, CategoricalVector(v), [:a, :b])
A[:b, :]
A[[:a,:c], :]
A[(:a,:x), :]
A[(:a,:x,:x), :]
```
"""
immutable CategoricalVector{T} <: AbstractVector{T}
data::AbstractVector{T}
end

Base.getindex(v::CategoricalVector, idx::Int) = v.data[idx]
Base.getindex(v::CategoricalVector, idx::AbstractVector) = CategoricalVector(v.data[idx])

Base.length(v::CategoricalVector) = length(v.data)
Base.size(v::CategoricalVector) = size(v.data)
Base.size(v::CategoricalVector, i) = size(v.data, i)
Base.indices(v::CategoricalVector) = indices(v.data)

axistrait{T}(::Type{CategoricalVector{T}}) = Categorical
checkaxis(::CategoricalVector) = nothing


## Add some special indexing for CategoricalVector{Tuple}'s to achieve something like
## Panda's hierarchical indexing

axisindexes{T<:Tuple,S}(ax::Axis{S,CategoricalVector{T}}, idx) = axisindexes(ax, (idx,))

function axisindexes{T<:Tuple,S}(ax::Axis{S,CategoricalVector{T}}, idx::Tuple)
collect(filter(ax_idx->_tuple_matches(ax.val[ax_idx], idx), indices(ax.val)...))
end

function _tuple_matches(element::Tuple, idx::Tuple)
length(idx) <= length(element) || return false

for (x, y) in zip(element, idx)
x == y || return false
end

return true
end

axisindexes{T<:Tuple,S}(ax::Axis{S,CategoricalVector{T}}, idx::AbstractArray) =
vcat([axisindexes(ax, i) for i in idx]...)
17 changes: 16 additions & 1 deletion src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ end
ex = Expr(:tuple)
n = 0
for i=1:length(I)
if axistrait(I[i]) <: Categorical && i <= length(Ax.parameters)
if I[i] <: Axis
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
else
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
end
n += 1

continue
end

if I[i] <: Idx
push!(ex.args, :(I[$i]))
n += 1
Expand All @@ -243,7 +254,11 @@ end
end
n += length(I[i])
elseif i <= length(Ax.parameters)
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
if I[i] <: Axis
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
else
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
end
n += 1
else
push!(ex.args, :(error("dimension ", $i, " does not have an axis to index")))
Expand Down
21 changes: 21 additions & 0 deletions test/categoricalvector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Test CategoricalVector with a hierarchical index (indexed using Tuples)
srand(1234)
data = reshape(1.:40., 20, 2)
v = collect(zip([:a, :b, :c][rand(1:3,20)], [:x,:y][rand(1:2,20)], [:x,:y][rand(1:2,20)]))
idx = sortperm(v)
A = AxisArray(data[idx,:], CategoricalVector(v[idx]), [:a, :b])
@test A[:b, :] == A[5:12, :]
@test A[[:a,:c], :] == A[[1:4;13:end], :]
@test A[(:a,:y), :] == A[2:4, :]
@test A[(:c,:y,:y), :] == A[16:end, :]
@test AxisArrays.axistrait(axes(A)[1]) <: AxisArrays.Categorical

v = CategoricalVector(collect([1; 8; 10:15]))
@test AxisArrays.axistrait(axes(A)[1]) <: AxisArrays.Categorical
A = AxisArray(reshape(1:16, 8, 2), v, [:a, :b])
@test A[Axis{:row}(CategoricalVector([15]))] == AxisArray(reshape(A.data[8, :], 1, 2), CategoricalVector([15]), [:a, :b])
@test A[Axis{:row}(CategoricalVector([15])), 1] == AxisArray([A.data[8, 1]], CategoricalVector([15]))
@test AxisArrays.axistrait(axes(A)[1]) <: AxisArrays.Categorical

# TODO: maybe make this work? Would require removing or modifying Base.getindex(A::AxisArray, idxs::Idx...)
# @test A[CategoricalVector([15]), 1] == AxisArray([A.data[8, 1]], CategoricalVector([15]))
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ using Base.Test
include("sortedvector.jl")
end

@testset "CategoricalVector" begin
include("categoricalvector.jl")
end

@testset "Search" begin
include("search.jl")
end
Expand Down

0 comments on commit 2565c2c

Please sign in to comment.