Skip to content

Commit

Permalink
Add flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
iamed2 committed Jun 5, 2017
1 parent 2565c2c commit 801b8d0
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/AxisArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ module AxisArrays

using Base: tail
using RangeArrays, IntervalSets
using Iterators
using Compat

export AxisArray, Axis, axisnames, axisvalues, axisdim, axes, atindex
export AxisArray, Axis, axisnames, axisvalues, axisdim, axes, atindex, flatten

# From IntervalSets:
export ClosedInterval, ..
Expand Down
86 changes: 86 additions & 0 deletions src/combine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,89 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
return result

end #join

function greatest_common_axis(As::AxisArray...)
length(As) == 1 && return ndims(first(As))

for (i, zip_axes) in enumerate(zip(axes.(As)...))
if !all(ax -> ax == zip_axes[1], zip_axes[2:end])
return i - 1
end
end

return minimum(map(ndims, As))
end

function flatten_array_axes(array_name, array_axes)
map(zip(repeated(array_name), product(map(Ax->Ax.val, array_axes)...))) do tup
tup_name, tup_idx = tup
return (tup_name, tup_idx...)
end
end

function flatten_axes(array_names, array_axes)
collect(chain(map(flatten_array_axes, array_names, array_axes)...))
end

"""
flatten(As::AxisArray...) -> AxisArray
flatten(last_dim::Integer, As::AxisArray...) -> AxisArray
Concatenates AxisArrays with equal leading axes into a single AxisArray.
All additional axes in any of the arrays are flattened into a single additional
CategoricalVector{Tuple} axis.
### Arguments
* `last_dim::Integer`: (optional) the greatest common dimension to share between all input
arrays. The remaining axes are flattened. If this argument is not
provided, the greatest common axis found among the input arrays is
used. All preceeding axes must also be common to each input array, at
the same dimension. Values from 0 up to one more than the minimum
number of dimensions across all input arrays are allowed.
* `As::AxisArray...`: AxisArrays to be flattened together.
"""
function flatten(As::AxisArray...; kwargs...)
gca = greatest_common_axis(As...)

return _flatten(gca, As...; kwargs...)
end

function flatten(last_dim::Integer, As::AxisArray...; kwargs...)
last_dim >= 0 || throw(ArgumentError("last_dim must be at least 0"))

if last_dim > minimum(map(ndims, As))
throw(ArgumentError(
"There must be at least $last_dim (last_dim) axes in each argument"
))
end

if last_dim > greatest_common_axis(As...)
throw(ArgumentError(
"The first $last_dim axes don't all match across all arguments"
))
end

return _flatten(last_dim, As...; kwargs...)
end

function _flatten(
last_dim::Integer,
As::AxisArray...;
array_names=1:length(As),
axis_name=nothing,
)
common_axes = axes(As[1])[1:last_dim]

if axis_name === nothing
axis_name = _defaultdimname(last_dim + 1)
elseif !isa(axis_name, Symbol)
throw(ArgumentError("axis_name must be a Symbol"))
end

new_data = cat(last_dim + 1, (view(A.data, repeated(:, last_dim + 1)...) for A in As)...)
new_axis = flatten_axes(array_names, map(A -> axes(A)[last_dim+1:end], As))

# TODO: Consider creating a SortedVector axis when all flattened axes are Dimensional
return AxisArray(new_data, common_axes..., CategoricalVector(new_axis))
end
17 changes: 17 additions & 0 deletions test/combine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,20 @@ ABdata[3:6,3:6,:,2] = Bdata
@test join(A,B,method=:left) == AxisArray(ABdata[1:4, 1:4, :, :], A.axes...)
@test join(A,B,method=:right) == AxisArray(ABdata[3:6, 3:6, :, :], B.axes...)
@test join(A,B,method=:outer) == join(A,B)

# flatten
A1 = AxisArray(A1data, Axis{:X}(1:2), Axis{:Y}(1:2))
A2 = AxisArray(reshape(A2data, size(A2data)..., 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:Z}([:foo]))

@test flatten(A1, A2; array_names=[:A1, :A2]) == AxisArray(cat(3, A1data, A2data), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:page}(CategoricalVector([(:A1,), (:A2, :foo)])))
@test flatten(A1; array_names=[:foo]) == AxisArray(reshape(A1, 2, 2, 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:page}(CategoricalVector([(:foo,)])))
@test flatten(A1; array_names=[:a], axis_name=:ax) == AxisArray(reshape(A1.data, size(A1)..., 1), axes(A1)..., Axis{:ax}(CategoricalVector([(:a,)])))

@test_throws ArgumentError flatten(-1, A1)
@test_throws ArgumentError flatten(10, A1)

A1ᵀ = transpose(A1)
@test flatten(A1, A1ᵀ) == flatten(0, A1, A1ᵀ)
@test_throws ArgumentError flatten(-1, A1, A1ᵀ)
@test_throws ArgumentError flatten(1, A1, A1ᵀ)
@test_throws ArgumentError flatten(10, A1, A1ᵀ)

0 comments on commit 801b8d0

Please sign in to comment.